Add basic windows support

This commit is contained in:
Fredrik Enestad 2024-08-01 12:48:24 +02:00
parent 8af7348958
commit dcc761b141
No known key found for this signature in database
33 changed files with 684 additions and 81 deletions

View file

@ -35,7 +35,6 @@ http = { workspace = true }
log = { workspace = true }
h2 = { workspace = true }
lru = { workspace = true }
nix = "~0.24.3"
clap = { version = "3.2.25", features = ["derive"] }
once_cell = { workspace = true }
serde = { version = "1.0", features = ["derive"] }
@ -46,7 +45,6 @@ libc = "0.2.70"
chrono = { version = "~0.4.31", features = ["alloc"], default-features = false }
thread_local = "1.0"
prometheus = "0.13"
daemonize = "0.5.0"
sentry = { version = "0.26", features = [
"backtrace",
"contexts",
@ -69,16 +67,25 @@ tokio-test = "0.4"
zstd = "0"
httpdate = "1"
[target.'cfg(unix)'.dependencies]
daemonize = "0.5.0"
nix = "~0.24.3"
[target.'cfg(windows)'.dependencies]
windows-sys = { version = "0.59.0", features = ["Win32_Networking_WinSock"] }
[target.'cfg(unix)'.dev-dependencies]
hyperlocal = "0.8"
jemallocator = "0.5"
[dev-dependencies]
matches = "0.1"
env_logger = "0.9"
reqwest = { version = "0.11", features = ["rustls"], default-features = false }
hyperlocal = "0.8"
hyper = "0.14"
jemallocator = "0.5"
[features]
default = ["openssl"]
openssl = ["pingora-openssl"]
boringssl = ["pingora-boringssl"]
patched_http1 = []
patched_http1 = []

View file

@ -16,7 +16,7 @@ use super::HttpSession;
use crate::connectors::{ConnectorOptions, TransportConnector};
use crate::protocols::http::v1::client::HttpSession as Http1Session;
use crate::protocols::http::v2::client::{drive_connection, Http2Session};
use crate::protocols::{Digest, Stream};
use crate::protocols::{Digest, Stream, UniqueIDType};
use crate::upstreams::peer::{Peer, ALPN};
use bytes::Bytes;
@ -47,7 +47,7 @@ pub(crate) struct ConnectionRefInner {
connection_stub: Stub,
closed: watch::Receiver<bool>,
ping_timeout_occurred: Arc<AtomicBool>,
id: i32,
id: UniqueIDType,
// max concurrent streams this connection is allowed to create
max_streams: usize,
// how many concurrent streams already active
@ -69,7 +69,7 @@ impl ConnectionRef {
send_req: SendRequest<Bytes>,
closed: watch::Receiver<bool>,
ping_timeout_occurred: Arc<AtomicBool>,
id: i32,
id: UniqueIDType,
max_streams: usize,
digest: Digest,
) -> Self {
@ -98,7 +98,7 @@ impl ConnectionRef {
self.0.current_streams.fetch_sub(1, Ordering::SeqCst);
}
pub fn id(&self) -> i32 {
pub fn id(&self) -> UniqueIDType {
self.0.id
}
@ -196,7 +196,7 @@ impl InUsePool {
// release a h2_stream, this functional will cause an ConnectionRef to be returned (if exist)
// the caller should update the ref and then decide where to put it (in use pool or idle)
fn release(&self, reuse_hash: u64, id: i32) -> Option<ConnectionRef> {
fn release(&self, reuse_hash: u64, id: UniqueIDType) -> Option<ConnectionRef> {
let pools = self.pools.read();
if let Some(pool) = pools.get(&reuse_hash) {
pool.remove(id)

View file

@ -17,10 +17,15 @@ use log::debug;
use pingora_error::{Context, Error, ErrorType::*, OrErr, Result};
use rand::seq::SliceRandom;
use std::net::SocketAddr as InetSocketAddr;
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;
#[cfg(unix)]
use crate::protocols::l4::ext::connect_uds;
use crate::protocols::l4::ext::{
connect_uds, connect_with as tcp_connect, set_dscp, set_recv_buf, set_tcp_fastopen_connect,
connect_with as tcp_connect, set_dscp, set_recv_buf, set_tcp_fastopen_connect,
};
use crate::protocols::l4::socket::SocketAddr;
use crate::protocols::l4::stream::Stream;
@ -39,9 +44,12 @@ where
P: Peer + Send + Sync,
{
if peer.get_proxy().is_some() {
#[cfg(unix)]
return proxy_connect(peer)
.await
.err_context(|| format!("Fail to establish CONNECT proxy: {}", peer));
#[cfg(windows)]
panic!("peer proxy not supported on windows")
}
let peer_addr = peer.address();
let mut stream: Stream =
@ -51,16 +59,21 @@ where
match peer_addr {
SocketAddr::Inet(addr) => {
let connect_future = tcp_connect(addr, bind_to.as_ref(), |socket| {
#[cfg(unix)]
let raw = socket.as_raw_fd();
#[cfg(windows)]
let raw = socket.as_raw_socket();
if peer.tcp_fast_open() {
set_tcp_fastopen_connect(socket.as_raw_fd())?;
set_tcp_fastopen_connect(raw)?;
}
if let Some(recv_buf) = peer.tcp_recv_buf() {
debug!("Setting recv buf size");
set_recv_buf(socket.as_raw_fd(), recv_buf)?;
set_recv_buf(raw, recv_buf)?;
}
if let Some(dscp) = peer.dscp() {
debug!("Setting dscp");
set_dscp(socket.as_raw_fd(), dscp)?;
set_dscp(raw, dscp)?;
}
Ok(())
});
@ -86,6 +99,7 @@ where
}
}
}
#[cfg(unix)]
SocketAddr::Unix(addr) => {
let connect_future = connect_uds(
addr.as_pathname()
@ -128,7 +142,10 @@ where
}
stream.set_nodelay()?;
#[cfg(unix)]
let digest = SocketDigest::from_raw_fd(stream.as_raw_fd());
#[cfg(windows)]
let digest = SocketDigest::from_raw_socket(stream.as_raw_socket());
digest
.peer_addr
.set(Some(peer_addr.clone()))
@ -164,12 +181,14 @@ pub(crate) fn bind_to_random<P: Peer>(
InetSocketAddr::V4(_) => bind_to_ips(v4_list),
InetSocketAddr::V6(_) => bind_to_ips(v6_list),
},
#[cfg(unix)]
SocketAddr::Unix(_) => None,
}
}
use crate::protocols::raw_connect;
#[cfg(unix)]
async fn proxy_connect<P: Peer>(peer: &P) -> Result<Stream> {
// safe to unwrap
let proxy = peer.get_proxy().unwrap();
@ -217,6 +236,7 @@ mod tests {
use std::collections::BTreeMap;
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
#[cfg(unix)]
use tokio::net::UnixListener;
#[tokio::test]
@ -285,6 +305,7 @@ mod tests {
assert!(new_session.is_ok());
}
#[cfg(unix)]
#[tokio::test]
async fn test_connect_proxy_fail() {
let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
@ -302,9 +323,11 @@ mod tests {
assert!(!e.retry());
}
#[cfg(unix)]
const MOCK_UDS_PATH: &str = "/tmp/test_unix_connect_proxy.sock";
// one-off mock server
#[cfg(unix)]
async fn mock_connect_server() {
let _ = std::fs::remove_file(MOCK_UDS_PATH);
let listener = UnixListener::bind(MOCK_UDS_PATH).unwrap();
@ -316,6 +339,7 @@ mod tests {
let _ = std::fs::remove_file(MOCK_UDS_PATH);
}
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread")]
async fn test_connect_proxy_work() {
tokio::spawn(async {
@ -336,10 +360,12 @@ mod tests {
assert!(new_session.is_ok());
}
#[cfg(unix)]
const MOCK_BAD_UDS_PATH: &str = "/tmp/test_unix_bad_connect_proxy.sock";
// one-off mock bad proxy
// closes connection upon accepting
#[cfg(unix)]
async fn mock_connect_bad_server() {
let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH);
let listener = UnixListener::bind(MOCK_BAD_UDS_PATH).unwrap();
@ -350,6 +376,7 @@ mod tests {
let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH);
}
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread")]
async fn test_connect_proxy_conn_closed() {
tokio::spawn(async {

View file

@ -195,11 +195,29 @@ impl TransportConnector {
let mut stream = l.into_inner();
// test_reusable_stream: we assume server would never actively send data
// first on an idle stream.
#[cfg(unix)]
if peer.matches_fd(stream.id()) && test_reusable_stream(&mut stream) {
Some(stream)
} else {
None
}
#[cfg(windows)]
{
use std::os::windows::io::{AsRawSocket, RawSocket};
struct WrappedRawSocket(RawSocket);
impl AsRawSocket for WrappedRawSocket {
fn as_raw_socket(&self) -> RawSocket {
self.0
}
}
if peer.matches_sock(WrappedRawSocket(stream.id() as RawSocket))
&& test_reusable_stream(&mut stream)
{
Some(stream)
} else {
None
}
}
}
Err(_) => {
error!("failed to acquire reusable stream");
@ -372,6 +390,7 @@ mod tests {
use crate::tls::ssl::SslMethod;
use crate::upstreams::peer::BasicPeer;
use tokio::io::AsyncWriteExt;
#[cfg(unix)]
use tokio::net::UnixListener;
// 192.0.2.1 is effectively a black hole
@ -403,9 +422,11 @@ mod tests {
assert!(reused);
}
#[cfg(unix)]
const MOCK_UDS_PATH: &str = "/tmp/test_unix_transport_connector.sock";
// one-off mock server
#[cfg(unix)]
async fn mock_connect_server() {
let _ = std::fs::remove_file(MOCK_UDS_PATH);
let listener = UnixListener::bind(MOCK_UDS_PATH).unwrap();
@ -416,6 +437,7 @@ mod tests {
}
let _ = std::fs::remove_file(MOCK_UDS_PATH);
}
#[cfg(unix)]
#[tokio::test(flavor = "multi_thread")]
async fn test_connect_uds() {
tokio::spawn(async {

View file

@ -20,8 +20,12 @@ use pingora_error::{
use std::fs::Permissions;
use std::io::ErrorKind;
use std::net::{SocketAddr, ToSocketAddrs};
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, FromRawFd};
#[cfg(unix)]
use std::os::unix::net::UnixListener as StdUnixListener;
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, FromRawSocket};
use std::time::Duration;
use tokio::net::TcpSocket;
@ -29,6 +33,7 @@ use crate::protocols::l4::ext::{set_dscp, set_tcp_fastopen_backlog};
use crate::protocols::l4::listener::Listener;
pub use crate::protocols::l4::stream::Stream;
use crate::protocols::TcpKeepalive;
#[cfg(unix)]
use crate::server::ListenFds;
const TCP_LISTENER_MAX_TRY: usize = 30;
@ -40,6 +45,7 @@ const LISTENER_BACKLOG: u32 = 65535;
#[derive(Clone, Debug)]
pub enum ServerAddress {
Tcp(String, Option<TcpSocketOptions>),
#[cfg(unix)]
Uds(String, Option<Permissions>),
}
@ -47,6 +53,7 @@ impl AsRef<str> for ServerAddress {
fn as_ref(&self) -> &str {
match &self {
Self::Tcp(l, _) => l,
#[cfg(unix)]
Self::Uds(l, _) => l,
}
}
@ -82,6 +89,7 @@ pub struct TcpSocketOptions {
// TODO: allow configuring reuseaddr, backlog, etc. from here?
}
#[cfg(unix)]
mod uds {
use super::{OrErr, Result};
use crate::protocols::l4::listener::Listener;
@ -151,17 +159,24 @@ fn apply_tcp_socket_options(sock: &TcpSocket, opt: Option<&TcpSocketOptions>) ->
}
if let Some(backlog) = opt.tcp_fastopen {
#[cfg(unix)]
set_tcp_fastopen_backlog(sock.as_raw_fd(), backlog)?;
#[cfg(windows)]
set_tcp_fastopen_backlog(sock.as_raw_socket(), backlog)?;
}
if let Some(dscp) = opt.dscp {
#[cfg(unix)]
set_dscp(sock.as_raw_fd(), dscp)?;
#[cfg(windows)]
set_dscp(sock.as_raw_socket(), dscp)?;
}
Ok(())
}
fn from_raw_fd(address: &ServerAddress, fd: i32) -> Result<Listener> {
match address {
#[cfg(unix)]
ServerAddress::Uds(addr, perm) => {
let std_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
// set permissions just in case
@ -169,7 +184,10 @@ fn from_raw_fd(address: &ServerAddress, fd: i32) -> Result<Listener> {
Ok(uds::set_backlog(std_listener, LISTENER_BACKLOG)?.into())
}
ServerAddress::Tcp(_, _) => {
#[cfg(unix)]
let std_listener_socket = unsafe { std::net::TcpStream::from_raw_fd(fd) };
#[cfg(windows)]
let std_listener_socket = unsafe { std::net::TcpStream::from_raw_socket(fd as u64) };
let listener_socket = TcpSocket::from_std_stream(std_listener_socket);
// Note that we call listen on an already listening socket
// POSIX undefined but on Linux it will update the backlog size
@ -231,6 +249,7 @@ async fn bind_tcp(addr: &str, opt: Option<TcpSocketOptions>) -> Result<Listener>
async fn bind(addr: &ServerAddress) -> Result<Listener> {
match addr {
#[cfg(unix)]
ServerAddress::Uds(l, perm) => uds::bind(l, perm.clone()),
ServerAddress::Tcp(l, opt) => bind_tcp(l, opt.clone()).await,
}
@ -253,6 +272,7 @@ impl ListenerEndpoint {
self.listen_addr.as_ref()
}
#[cfg(unix)]
pub async fn listen(&mut self, fds: Option<ListenFds>) -> Result<()> {
if self.listener.is_some() {
return Ok(());
@ -278,6 +298,12 @@ impl ListenerEndpoint {
Ok(())
}
#[cfg(windows)]
pub async fn listen(&mut self) -> Result<()> {
self.listener = Some(bind(&self.listen_addr).await?);
Ok(())
}
fn apply_stream_settings(&self, stream: &mut Stream) -> Result<()> {
// settings are applied based on whether the underlying stream supports it
stream.set_nodelay()?;
@ -288,7 +314,10 @@ impl ListenerEndpoint {
stream.set_keepalive(ka)?;
}
if let Some(dscp) = op.dscp {
#[cfg(unix)]
set_dscp(stream.as_raw_fd(), dscp)?;
#[cfg(windows)]
set_dscp(stream.as_raw_socket(), dscp)?;
}
Ok(())
}
@ -315,7 +344,13 @@ mod test {
async fn test_listen_tcp() {
let addr = "127.0.0.1:7100";
let mut listener = ListenerEndpoint::new(ServerAddress::Tcp(addr.into(), None));
listener.listen(None).await.unwrap();
listener
.listen(
#[cfg(unix)]
None,
)
.await
.unwrap();
tokio::spawn(async move {
// just try to accept once
listener.accept().await.unwrap();
@ -332,7 +367,13 @@ mod test {
..Default::default()
});
let mut listener = ListenerEndpoint::new(ServerAddress::Tcp("[::]:7101".into(), sock_opt));
listener.listen(None).await.unwrap();
listener
.listen(
#[cfg(unix)]
None,
)
.await
.unwrap();
tokio::spawn(async move {
// just try to accept twice
listener.accept().await.unwrap();
@ -346,11 +387,18 @@ mod test {
.expect("can connect to v6 addr");
}
#[cfg(unix)]
#[tokio::test]
async fn test_listen_uds() {
let addr = "/tmp/test_listen_uds";
let mut listener = ListenerEndpoint::new(ServerAddress::Uds(addr.into(), None));
listener.listen(None).await.unwrap();
listener
.listen(
#[cfg(unix)]
None,
)
.await
.unwrap();
tokio::spawn(async move {
// just try to accept once
listener.accept().await.unwrap();

View file

@ -18,6 +18,7 @@ mod l4;
mod tls;
use crate::protocols::Stream;
#[cfg(unix)]
use crate::server::ListenFds;
use pingora_error::Result;
@ -36,10 +37,11 @@ struct TransportStackBuilder {
}
impl TransportStackBuilder {
pub fn build(&mut self, upgrade_listeners: Option<ListenFds>) -> TransportStack {
pub fn build(&mut self, #[cfg(unix)] upgrade_listeners: Option<ListenFds>) -> TransportStack {
TransportStack {
l4: ListenerEndpoint::new(self.l4.clone()),
tls: self.tls.take().map(|tls| Arc::new(tls.build())),
#[cfg(unix)]
upgrade_listeners,
}
}
@ -49,6 +51,7 @@ pub(crate) struct TransportStack {
l4: ListenerEndpoint,
tls: Option<Arc<Acceptor>>,
// listeners sent from the old process for graceful upgrade
#[cfg(unix)]
upgrade_listeners: Option<ListenFds>,
}
@ -58,7 +61,12 @@ impl TransportStack {
}
pub async fn listen(&mut self) -> Result<()> {
self.l4.listen(self.upgrade_listeners.take()).await
self.l4
.listen(
#[cfg(unix)]
self.upgrade_listeners.take(),
)
.await
}
pub async fn accept(&mut self) -> Result<UninitializedStream> {
@ -109,6 +117,7 @@ impl Listeners {
}
/// Create a new [`Listeners`] with a Unix domain socket endpoint from the given string.
#[cfg(unix)]
pub fn uds(addr: &str, perm: Option<Permissions>) -> Self {
let mut listeners = Self::new();
listeners.add_uds(addr, perm);
@ -136,6 +145,7 @@ impl Listeners {
}
/// Add a Unix domain socket endpoint to `self`.
#[cfg(unix)]
pub fn add_uds(&mut self, addr: &str, perm: Option<Permissions>) {
self.add_address(ServerAddress::Uds(addr.into(), perm));
}
@ -168,10 +178,18 @@ impl Listeners {
self.stacks.push(TransportStackBuilder { l4, tls })
}
pub(crate) fn build(&mut self, upgrade_listeners: Option<ListenFds>) -> Vec<TransportStack> {
pub(crate) fn build(
&mut self,
#[cfg(unix)] upgrade_listeners: Option<ListenFds>,
) -> Vec<TransportStack> {
self.stacks
.iter_mut()
.map(|b| b.build(upgrade_listeners.clone()))
.map(|b| {
b.build(
#[cfg(unix)]
upgrade_listeners.clone(),
)
})
.collect()
}
@ -194,7 +212,10 @@ mod test {
let mut listeners = Listeners::tcp(addr1);
listeners.add_tcp(addr2);
let listeners = listeners.build(None);
let listeners = listeners.build(
#[cfg(unix)]
None,
);
assert_eq!(listeners.len(), 2);
for mut listener in listeners {
tokio::spawn(async move {
@ -220,7 +241,13 @@ mod test {
let cert_path = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR"));
let key_path = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR"));
let mut listeners = Listeners::tls(addr, &cert_path, &key_path).unwrap();
let mut listener = listeners.build(None).pop().unwrap();
let mut listener = listeners
.build(
#[cfg(unix)]
None,
)
.pop()
.unwrap();
tokio::spawn(async move {
listener.listen().await.unwrap();

View file

@ -62,7 +62,10 @@ impl Default for TimingDigest {
#[derive(Debug)]
/// The interface to return socket-related information
pub struct SocketDigest {
#[cfg(unix)]
raw_fd: std::os::unix::io::RawFd,
#[cfg(windows)]
raw_sock: std::os::windows::io::RawSocket,
/// Remote socket address
pub peer_addr: OnceCell<Option<SocketAddr>>,
/// Local socket address
@ -70,6 +73,7 @@ pub struct SocketDigest {
}
impl SocketDigest {
#[cfg(unix)]
pub fn from_raw_fd(raw_fd: std::os::unix::io::RawFd) -> SocketDigest {
SocketDigest {
raw_fd,
@ -78,22 +82,47 @@ impl SocketDigest {
}
}
#[cfg(windows)]
pub fn from_raw_socket(raw_sock: std::os::windows::io::RawSocket) -> SocketDigest {
SocketDigest {
raw_sock,
peer_addr: OnceCell::new(),
local_addr: OnceCell::new(),
}
}
#[cfg(unix)]
pub fn peer_addr(&self) -> Option<&SocketAddr> {
self.peer_addr
.get_or_init(|| SocketAddr::from_raw_fd(self.raw_fd, true))
.as_ref()
}
#[cfg(windows)]
pub fn peer_addr(&self) -> Option<&SocketAddr> {
self.peer_addr
.get_or_init(|| SocketAddr::from_raw_socket(self.raw_sock, true))
.as_ref()
}
#[cfg(unix)]
pub fn local_addr(&self) -> Option<&SocketAddr> {
self.local_addr
.get_or_init(|| SocketAddr::from_raw_fd(self.raw_fd, false))
.as_ref()
}
#[cfg(windows)]
pub fn local_addr(&self) -> Option<&SocketAddr> {
self.local_addr
.get_or_init(|| SocketAddr::from_raw_socket(self.raw_sock, false))
.as_ref()
}
fn is_inet(&self) -> bool {
self.local_addr().and_then(|p| p.as_inet()).is_some()
}
#[cfg(unix)]
pub fn tcp_info(&self) -> Option<TCP_INFO> {
if self.is_inet() {
get_tcp_info(self.raw_fd).ok()
@ -102,6 +131,16 @@ impl SocketDigest {
}
}
#[cfg(windows)]
pub fn tcp_info(&self) -> Option<TCP_INFO> {
if self.is_inet() {
get_tcp_info(self.raw_sock).ok()
} else {
None
}
}
#[cfg(unix)]
pub fn get_recv_buf(&self) -> Option<usize> {
if self.is_inet() {
get_recv_buf(self.raw_fd).ok()
@ -109,6 +148,15 @@ impl SocketDigest {
None
}
}
#[cfg(windows)]
pub fn get_recv_buf(&self) -> Option<usize> {
if self.is_inet() {
get_recv_buf(self.raw_sock).ok()
} else {
None
}
}
}
/// The interface to return timing information

View file

@ -28,7 +28,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::body::{BodyReader, BodyWriter};
use super::common::*;
use crate::protocols::http::HttpTask;
use crate::protocols::{Digest, SocketAddr, Stream, UniqueID};
use crate::protocols::{Digest, SocketAddr, Stream, UniqueID, UniqueIDType};
use crate::utils::{BufRef, KVRef};
/// The HTTP 1.x client session
@ -717,7 +717,7 @@ pub(crate) fn http_req_header_to_wire(req: &RequestHeader) -> Option<BytesMut> {
}
impl UniqueID for HttpSession {
fn id(&self) -> i32 {
fn id(&self) -> UniqueIDType {
self.underlying_stream.id()
}
}

View file

@ -30,7 +30,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::watch;
use crate::connectors::http::v2::ConnectionRef;
use crate::protocols::{Digest, SocketAddr};
use crate::protocols::{Digest, SocketAddr, UniqueIDType};
pub const PING_TIMEDOUT: ErrorType = ErrorType::new("PingTimedout");
@ -336,7 +336,7 @@ impl Http2Session {
}
/// the FD of the underlying connection
pub fn fd(&self) -> i32 {
pub fn fd(&self) -> UniqueIDType {
self.conn.id()
}
@ -415,7 +415,7 @@ use tokio::sync::oneshot;
pub async fn drive_connection<S>(
mut c: client::Connection<S>,
id: i32,
id: UniqueIDType,
closed: watch::Sender<bool>,
ping_interval: Option<Duration>,
ping_timeout_occurred: Arc<AtomicBool>,
@ -469,7 +469,7 @@ async fn do_ping_pong(
interval: Duration,
tx: oneshot::Sender<()>,
dropped: Arc<AtomicBool>,
id: i32,
id: UniqueIDType,
) {
// delay before sending the first ping, no need to race with the first request
tokio::time::sleep(interval).await;

View file

@ -16,6 +16,7 @@
#![allow(non_camel_case_types)]
#[cfg(unix)]
use libc::socklen_t;
#[cfg(target_os = "linux")]
use libc::{c_int, c_ulonglong, c_void};
@ -23,9 +24,14 @@ use pingora_error::{Error, ErrorType::*, OrErr, Result};
use std::io::{self, ErrorKind};
use std::mem;
use std::net::SocketAddr;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::time::Duration;
use tokio::net::{TcpSocket, TcpStream, UnixStream};
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::net::{TcpSocket, TcpStream};
/// The (copy of) the kernel struct tcp_info returns
#[repr(C)]
@ -96,9 +102,14 @@ impl TCP_INFO {
}
/// Return the size of [`TCP_INFO`]
#[cfg(unix)]
pub fn len() -> socklen_t {
mem::size_of::<Self>() as socklen_t
}
#[cfg(windows)]
pub fn len() -> usize {
mem::size_of::<Self>()
}
}
#[cfg(target_os = "linux")]
@ -165,7 +176,7 @@ fn ip_bind_addr_no_port(fd: RawFd, val: bool) -> io::Result<()> {
set_opt(fd, libc::IPPROTO_IP, IP_BIND_ADDRESS_NO_PORT, val as c_int)
}
#[cfg(not(target_os = "linux"))]
#[cfg(all(unix, not(target_os = "linux")))]
fn ip_bind_addr_no_port(_fd: RawFd, _val: bool) -> io::Result<()> {
Ok(())
}
@ -208,22 +219,32 @@ fn set_keepalive(fd: RawFd, ka: &TcpKeepalive) -> io::Result<()> {
set_so_keepalive_count(fd, ka.count)
}
#[cfg(not(target_os = "linux"))]
#[cfg(all(unix, not(target_os = "linux")))]
fn set_keepalive(_fd: RawFd, _ka: &TcpKeepalive) -> io::Result<()> {
Ok(())
}
#[cfg(windows)]
fn set_keepalive(_sock: RawSocket, _ka: &TcpKeepalive) -> io::Result<()> {
Ok(())
}
/// Get the kernel TCP_INFO for the given FD.
#[cfg(target_os = "linux")]
pub fn get_tcp_info(fd: RawFd) -> io::Result<TCP_INFO> {
get_opt_sized(fd, libc::IPPROTO_TCP, libc::TCP_INFO)
}
#[cfg(not(target_os = "linux"))]
#[cfg(all(unix, not(target_os = "linux")))]
pub fn get_tcp_info(_fd: RawFd) -> io::Result<TCP_INFO> {
Ok(unsafe { TCP_INFO::new() })
}
#[cfg(windows)]
pub fn get_tcp_info(_sock: RawSocket) -> io::Result<TCP_INFO> {
Ok(unsafe { TCP_INFO::new() })
}
/// Set the TCP receive buffer size. See SO_RCVBUF.
#[cfg(target_os = "linux")]
pub fn set_recv_buf(fd: RawFd, val: usize) -> Result<()> {
@ -231,21 +252,31 @@ pub fn set_recv_buf(fd: RawFd, val: usize) -> Result<()> {
.or_err(ConnectError, "failed to set SO_RCVBUF")
}
#[cfg(not(target_os = "linux"))]
#[cfg(all(unix, not(target_os = "linux")))]
pub fn set_recv_buf(_fd: RawFd, _: usize) -> Result<()> {
Ok(())
}
#[cfg(windows)]
pub fn set_recv_buf(_sock: RawSocket, _: usize) -> Result<()> {
Ok(())
}
#[cfg(target_os = "linux")]
pub fn get_recv_buf(fd: RawFd) -> io::Result<usize> {
get_opt_sized::<c_int>(fd, libc::SOL_SOCKET, libc::SO_RCVBUF).map(|v| v as usize)
}
#[cfg(not(target_os = "linux"))]
#[cfg(all(unix, not(target_os = "linux")))]
pub fn get_recv_buf(_fd: RawFd) -> io::Result<usize> {
Ok(0)
}
#[cfg(windows)]
pub fn get_recv_buf(_sock: RawSocket) -> io::Result<usize> {
Ok(0)
}
/// Enable client side TCP fast open.
#[cfg(target_os = "linux")]
pub fn set_tcp_fastopen_connect(fd: RawFd) -> Result<()> {
@ -258,11 +289,16 @@ pub fn set_tcp_fastopen_connect(fd: RawFd) -> Result<()> {
.or_err(ConnectError, "failed to set TCP_FASTOPEN_CONNECT")
}
#[cfg(not(target_os = "linux"))]
#[cfg(all(unix, not(target_os = "linux")))]
pub fn set_tcp_fastopen_connect(_fd: RawFd) -> Result<()> {
Ok(())
}
#[cfg(windows)]
pub fn set_tcp_fastopen_connect(_sock: RawSocket) -> Result<()> {
Ok(())
}
/// Enable server side TCP fast open.
#[cfg(target_os = "linux")]
pub fn set_tcp_fastopen_backlog(fd: RawFd, backlog: usize) -> Result<()> {
@ -270,11 +306,15 @@ pub fn set_tcp_fastopen_backlog(fd: RawFd, backlog: usize) -> Result<()> {
.or_err(ConnectError, "failed to set TCP_FASTOPEN")
}
#[cfg(not(target_os = "linux"))]
#[cfg(all(unix, not(target_os = "linux")))]
pub fn set_tcp_fastopen_backlog(_fd: RawFd, _backlog: usize) -> Result<()> {
Ok(())
}
#[cfg(windows)]
pub fn set_tcp_fastopen_backlog(_sock: RawSocket, _backlog: usize) -> Result<()> {
Ok(())
}
#[cfg(target_os = "linux")]
pub fn set_dscp(fd: RawFd, value: u8) -> Result<()> {
use super::socket::SocketAddr;
@ -295,17 +335,22 @@ pub fn set_dscp(fd: RawFd, value: u8) -> Result<()> {
}
}
#[cfg(not(target_os = "linux"))]
#[cfg(all(unix, not(target_os = "linux")))]
pub fn set_dscp(_fd: RawFd, _value: u8) -> Result<()> {
Ok(())
}
#[cfg(windows)]
pub fn set_dscp(_sock: RawSocket, _value: u8) -> Result<()> {
Ok(())
}
#[cfg(target_os = "linux")]
pub fn get_socket_cookie(fd: RawFd) -> io::Result<u64> {
get_opt_sized::<c_ulonglong>(fd, libc::SOL_SOCKET, libc::SO_COOKIE)
}
#[cfg(not(target_os = "linux"))]
#[cfg(all(unix, not(target_os = "linux")))]
pub fn get_socket_cookie(_fd: RawFd) -> io::Result<u64> {
Ok(0) // SO_COOKIE is a Linux concept
}
@ -327,7 +372,8 @@ pub(crate) async fn connect_with<F: FnOnce(&TcpSocket) -> Result<()>>(
}
.or_err(SocketError, "failed to create socket")?;
if cfg!(target_os = "linux") {
#[cfg(target_os = "linux")]
{
ip_bind_addr_no_port(socket.as_raw_fd(), true)
.or_err(SocketError, "failed to set socket opts")?;
@ -337,6 +383,12 @@ pub(crate) async fn connect_with<F: FnOnce(&TcpSocket) -> Result<()>>(
.or_err_with(BindError, || format!("failed to bind to socket {}", *baddr))?;
};
}
#[cfg(windows)]
if let Some(baddr) = bind_to {
socket
.bind(*baddr)
.or_err_with(BindError, || format!("failed to bind to socket {}", *baddr))?;
};
// TODO: add support for bind on other platforms
set_socket(&socket)?;
@ -355,6 +407,7 @@ pub async fn connect(addr: &SocketAddr, bind_to: Option<&SocketAddr>) -> Result<
}
/// connect() to the given Unix domain socket
#[cfg(unix)]
pub async fn connect_uds(path: &std::path::Path) -> Result<UnixStream> {
UnixStream::connect(path)
.await
@ -396,9 +449,12 @@ impl std::fmt::Display for TcpKeepalive {
/// Apply the given TCP keepalive settings to the given connection
pub fn set_tcp_keepalive(stream: &TcpStream, ka: &TcpKeepalive) -> Result<()> {
let fd = stream.as_raw_fd();
#[cfg(unix)]
let raw = stream.as_raw_fd();
#[cfg(windows)]
let raw = stream.as_raw_socket();
// TODO: check localhost or if keepalive is already set
set_keepalive(fd, ka).or_err(ConnectError, "failed to set keepalive")
set_keepalive(raw, ka).or_err(ConnectError, "failed to set keepalive")
}
#[cfg(test)]
@ -409,7 +465,10 @@ mod test {
fn test_set_recv_buf() {
use tokio::net::TcpSocket;
let socket = TcpSocket::new_v4().unwrap();
#[cfg(unix)]
set_recv_buf(socket.as_raw_fd(), 102400).unwrap();
#[cfg(windows)]
set_recv_buf(socket.as_raw_socket(), 102400).unwrap();
#[cfg(target_os = "linux")]
{

View file

@ -14,17 +14,22 @@
//! Listeners
use std::io;
use std::os::unix::io::AsRawFd;
use tokio::net::{TcpListener, UnixListener};
use crate::protocols::digest::{GetSocketDigest, SocketDigest};
use crate::protocols::l4::stream::Stream;
use std::io;
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;
use tokio::net::TcpListener;
#[cfg(unix)]
use tokio::net::UnixListener;
/// The type for generic listener for both TCP and Unix domain socket
#[derive(Debug)]
pub enum Listener {
Tcp(TcpListener),
#[cfg(unix)]
Unix(UnixListener),
}
@ -34,12 +39,14 @@ impl From<TcpListener> for Listener {
}
}
#[cfg(unix)]
impl From<UnixListener> for Listener {
fn from(s: UnixListener) -> Self {
Self::Unix(s)
}
}
#[cfg(unix)]
impl AsRawFd for Listener {
fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
match &self {
@ -49,13 +56,25 @@ impl AsRawFd for Listener {
}
}
#[cfg(windows)]
impl AsRawSocket for Listener {
fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
match &self {
Self::Tcp(l) => l.as_raw_socket(),
}
}
}
impl Listener {
/// Accept a connection from the listening endpoint
pub async fn accept(&self) -> io::Result<Stream> {
match &self {
Self::Tcp(l) => l.accept().await.map(|(stream, peer_addr)| {
let mut s: Stream = stream.into();
#[cfg(unix)]
let digest = SocketDigest::from_raw_fd(s.as_raw_fd());
#[cfg(windows)]
let digest = SocketDigest::from_raw_socket(s.as_raw_socket());
digest
.peer_addr
.set(Some(peer_addr.into()))
@ -66,6 +85,7 @@ impl Listener {
// and init it in the socket digest here
s
}),
#[cfg(unix)]
Self::Unix(l) => l.accept().await.map(|(stream, peer_addr)| {
let mut s: Stream = stream.into();
let digest = SocketDigest::from_raw_fd(s.as_raw_fd());

View file

@ -16,11 +16,14 @@
use crate::{Error, OrErr};
use log::warn;
#[cfg(unix)]
use nix::sys::socket::{getpeername, getsockname, SockaddrStorage};
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::net::SocketAddr as StdSockAddr;
#[cfg(unix)]
use std::os::unix::net::SocketAddr as StdUnixSockAddr;
#[cfg(unix)]
use tokio::net::unix::SocketAddr as TokioUnixSockAddr;
/// [`SocketAddr`] is a storage type that contains either a Internet (IP address)
@ -28,6 +31,7 @@ use tokio::net::unix::SocketAddr as TokioUnixSockAddr;
#[derive(Debug, Clone)]
pub enum SocketAddr {
Inet(StdSockAddr),
#[cfg(unix)]
Unix(StdUnixSockAddr),
}
@ -42,6 +46,7 @@ impl SocketAddr {
}
/// Get a reference to the Unix domain socket if it is one
#[cfg(unix)]
pub fn as_unix(&self) -> Option<&StdUnixSockAddr> {
if let SocketAddr::Unix(addr) = self {
Some(addr)
@ -57,6 +62,7 @@ impl SocketAddr {
}
}
#[cfg(unix)]
fn from_sockaddr_storage(sock: &SockaddrStorage) -> Option<SocketAddr> {
if let Some(v4) = sock.as_sockaddr_in() {
return Some(SocketAddr::Inet(StdSockAddr::V4(
@ -77,6 +83,7 @@ impl SocketAddr {
))
}
#[cfg(unix)]
pub fn from_raw_fd(fd: std::os::unix::io::RawFd, peer_addr: bool) -> Option<SocketAddr> {
let sockaddr_storage = if peer_addr {
getpeername(fd)
@ -90,12 +97,28 @@ impl SocketAddr {
Err(_e) => None,
}
}
#[cfg(windows)]
pub fn from_raw_socket(
sock: std::os::windows::io::RawSocket,
is_peer_addr: bool,
) -> Option<SocketAddr> {
use crate::protocols::windows::{local_addr, peer_addr};
if is_peer_addr {
peer_addr(sock)
} else {
local_addr(sock)
}
.map(|s| s.into())
.ok()
}
}
impl std::fmt::Display for SocketAddr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
SocketAddr::Inet(addr) => write!(f, "{addr}"),
#[cfg(unix)]
SocketAddr::Unix(addr) => {
if let Some(path) = addr.as_pathname() {
write!(f, "{}", path.display())
@ -111,6 +134,7 @@ impl Hash for SocketAddr {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
Self::Inet(sockaddr) => sockaddr.hash(state),
#[cfg(unix)]
Self::Unix(sockaddr) => {
if let Some(path) = sockaddr.as_pathname() {
// use the underlying path as the hash
@ -130,6 +154,7 @@ impl PartialEq for SocketAddr {
fn eq(&self, other: &Self) -> bool {
match self {
Self::Inet(addr) => Some(addr) == other.as_inet(),
#[cfg(unix)]
Self::Unix(addr) => {
let path = addr.as_pathname();
// can only compare UDS with path, assume false on all unnamed UDS
@ -156,6 +181,7 @@ impl Ord for SocketAddr {
Ordering::Less
}
}
#[cfg(unix)]
Self::Unix(addr) => {
if let Some(o) = other.as_unix() {
// NOTE: unnamed UDS are consider the same
@ -175,6 +201,7 @@ impl std::str::FromStr for SocketAddr {
type Err = Box<Error>;
// This is very basic parsing logic, it might treat invalid IP:PORT str as UDS path
#[cfg(unix)]
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.starts_with("unix:") {
// format unix:/tmp/server.socket
@ -195,6 +222,11 @@ impl std::str::FromStr for SocketAddr {
}
}
}
#[cfg(windows)]
fn from_str(s: &str) -> Result<Self, Self::Err> {
let addr = StdSockAddr::from_str(s).or_err(crate::BindError, "invalid socket addr")?;
Ok(SocketAddr::Inet(addr))
}
}
impl std::net::ToSocketAddrs for SocketAddr {
@ -219,6 +251,7 @@ impl From<StdSockAddr> for SocketAddr {
}
}
#[cfg(unix)]
impl From<StdUnixSockAddr> for SocketAddr {
fn from(sockaddr: StdUnixSockAddr) -> Self {
SocketAddr::Unix(sockaddr)
@ -228,6 +261,7 @@ impl From<StdUnixSockAddr> for SocketAddr {
// TODO: ideally mio/tokio will start using the std version of the unix `SocketAddr`
// so we can avoid a fallible conversion
// https://github.com/tokio-rs/mio/issues/1527
#[cfg(unix)]
impl TryFrom<TokioUnixSockAddr> for SocketAddr {
type Error = String;
@ -251,12 +285,14 @@ mod test {
assert!(ip.as_inet().is_some());
}
#[cfg(unix)]
#[test]
fn parse_uds() {
let uds: SocketAddr = "/tmp/my.sock".parse().unwrap();
assert!(uds.as_unix().is_some());
}
#[cfg(unix)]
#[test]
fn parse_uds_with_prefix() {
let uds: SocketAddr = "unix:/tmp/my.sock".parse().unwrap();

View file

@ -18,13 +18,18 @@ use async_trait::async_trait;
use futures::FutureExt;
use log::{debug, error};
use pingora_error::{ErrorType::*, OrErr, Result};
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant, SystemTime};
use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
use tokio::net::{TcpStream, UnixStream};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
use crate::protocols::raw_connect::ProxyDigest;
@ -37,6 +42,7 @@ use crate::upstreams::peer::Tracer;
#[derive(Debug)]
enum RawStream {
Tcp(TcpStream),
#[cfg(unix)]
Unix(UnixStream),
}
@ -50,6 +56,7 @@ impl AsyncRead for RawStream {
unsafe {
match &mut Pin::get_unchecked_mut(self) {
RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
#[cfg(unix)]
RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
}
}
@ -62,6 +69,7 @@ impl AsyncWrite for RawStream {
unsafe {
match &mut Pin::get_unchecked_mut(self) {
RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
#[cfg(unix)]
RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
}
}
@ -72,6 +80,7 @@ impl AsyncWrite for RawStream {
unsafe {
match &mut Pin::get_unchecked_mut(self) {
RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
#[cfg(unix)]
RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
}
}
@ -82,6 +91,7 @@ impl AsyncWrite for RawStream {
unsafe {
match &mut Pin::get_unchecked_mut(self) {
RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
#[cfg(unix)]
RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
}
}
@ -96,6 +106,7 @@ impl AsyncWrite for RawStream {
unsafe {
match &mut Pin::get_unchecked_mut(self) {
RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
#[cfg(unix)]
RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
}
}
@ -104,11 +115,13 @@ impl AsyncWrite for RawStream {
fn is_write_vectored(&self) -> bool {
match self {
RawStream::Tcp(s) => s.is_write_vectored(),
#[cfg(unix)]
RawStream::Unix(s) => s.is_write_vectored(),
}
}
}
#[cfg(unix)]
impl AsRawFd for RawStream {
fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
match self {
@ -118,6 +131,15 @@ impl AsRawFd for RawStream {
}
}
#[cfg(windows)]
impl AsRawSocket for RawStream {
fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
match self {
RawStream::Tcp(s) => s.as_raw_socket(),
}
}
}
// Large read buffering helps reducing syscalls with little trade-off
// Ssl layer always does "small" reads in 16k (TLS record size) so L4 read buffer helps a lot.
const BUF_READ_SIZE: usize = 64 * 1024;
@ -180,6 +202,7 @@ impl From<TcpStream> for Stream {
}
}
#[cfg(unix)]
impl From<UnixStream> for Stream {
fn from(s: UnixStream) -> Self {
Stream {
@ -195,18 +218,34 @@ impl From<UnixStream> for Stream {
}
}
#[cfg(unix)]
impl AsRawFd for Stream {
fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
self.stream.get_ref().as_raw_fd()
}
}
#[cfg(windows)]
impl AsRawSocket for Stream {
fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
self.stream.get_ref().as_raw_socket()
}
}
#[cfg(unix)]
impl UniqueID for Stream {
fn id(&self) -> i32 {
self.as_raw_fd()
}
}
#[cfg(windows)]
impl UniqueID for Stream {
fn id(&self) -> usize {
self.as_raw_socket() as usize
}
}
impl Ssl for Stream {}
#[async_trait]
@ -264,6 +303,7 @@ impl Drop for Stream {
/* use nodelay/local_addr function to detect socket status */
let ret = match &self.stream.get_ref() {
RawStream::Tcp(s) => s.nodelay().err(),
#[cfg(unix)]
RawStream::Unix(s) => s.local_addr().err(),
};
if let Some(e) = ret {

View file

@ -19,6 +19,8 @@ pub mod http;
pub mod l4;
pub mod raw_connect;
pub mod ssl;
#[cfg(windows)]
mod windows;
pub use digest::{
Digest, GetProxyDigest, GetSocketDigest, GetTimingDigest, ProtoDigest, SocketDigest,
@ -38,11 +40,16 @@ pub trait Shutdown {
async fn shutdown(&mut self) -> ();
}
#[cfg(unix)]
pub type UniqueIDType = i32;
#[cfg(windows)]
pub type UniqueIDType = usize;
/// Define how a given session/connection identifies itself.
pub trait UniqueID {
/// The ID returned should be unique among all existing connections of the same type.
/// But ID can be recycled after a connection is shutdown.
fn id(&self) -> i32;
fn id(&self) -> UniqueIDType;
}
/// Interface to get TLS info
@ -126,7 +133,7 @@ mod ext_io_impl {
async fn shutdown(&mut self) -> () {}
}
impl UniqueID for Mock {
fn id(&self) -> i32 {
fn id(&self) -> UniqueIDType {
0
}
}
@ -154,7 +161,7 @@ mod ext_io_impl {
async fn shutdown(&mut self) -> () {}
}
impl<T> UniqueID for Cursor<T> {
fn id(&self) -> i32 {
fn id(&self) -> UniqueIDType {
0
}
}
@ -182,7 +189,7 @@ mod ext_io_impl {
async fn shutdown(&mut self) -> () {}
}
impl UniqueID for DuplexStream {
fn id(&self) -> i32 {
fn id(&self) -> UniqueIDType {
0
}
}
@ -204,15 +211,26 @@ mod ext_io_impl {
}
}
#[cfg(unix)]
pub(crate) trait ConnFdReusable {
fn check_fd_match<V: AsRawFd>(&self, fd: V) -> bool;
}
#[cfg(windows)]
pub(crate) trait ConnSockReusable {
fn check_sock_match<V: AsRawSocket>(&self, sock: V) -> bool;
}
use l4::socket::SocketAddr;
use log::{debug, error};
#[cfg(unix)]
use nix::sys::socket::{getpeername, SockaddrStorage, UnixAddr};
use std::{net::SocketAddr as InetSocketAddr, os::unix::prelude::AsRawFd, path::Path};
#[cfg(unix)]
use std::os::unix::prelude::AsRawFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;
use std::{net::SocketAddr as InetSocketAddr, path::Path};
#[cfg(unix)]
impl ConnFdReusable for SocketAddr {
fn check_fd_match<V: AsRawFd>(&self, fd: V) -> bool {
match self {
@ -224,7 +242,16 @@ impl ConnFdReusable for SocketAddr {
}
}
}
#[cfg(windows)]
impl ConnSockReusable for SocketAddr {
fn check_sock_match<V: AsRawSocket>(&self, sock: V) -> bool {
match self {
SocketAddr::Inet(addr) => addr.check_sock_match(sock),
}
}
}
#[cfg(unix)]
impl ConnFdReusable for Path {
fn check_fd_match<V: AsRawFd>(&self, fd: V) -> bool {
let fd = fd.as_raw_fd();
@ -252,6 +279,7 @@ impl ConnFdReusable for Path {
}
}
#[cfg(unix)]
impl ConnFdReusable for InetSocketAddr {
fn check_fd_match<V: AsRawFd>(&self, fd: V) -> bool {
let fd = fd.as_raw_fd();
@ -281,3 +309,33 @@ impl ConnFdReusable for InetSocketAddr {
}
}
}
#[cfg(windows)]
impl ConnSockReusable for InetSocketAddr {
fn check_sock_match<V: AsRawSocket>(&self, sock: V) -> bool {
let sock = sock.as_raw_socket();
match windows::peer_addr(sock) {
Ok(peer) => {
const ZERO: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0));
if self.ip() == ZERO {
// https://www.rfc-editor.org/rfc/rfc1122.html#section-3.2.1.3
// 0.0.0.0 should only be used as source IP not destination
// However in some systems this destination IP is mapped to 127.0.0.1.
// We just skip this check here to avoid false positive mismatch.
return true;
}
if self == &peer {
debug!("Inet FD to: {self} is reusable");
true
} else {
error!("Crit: FD mismatch: fd: {sock:?}, addr: {self}, peer: {peer}",);
false
}
}
Err(e) => {
debug!("Idle connection is broken: {e:?}");
false
}
}
}
}

View file

@ -93,6 +93,8 @@ impl<T> SslStream<T> {
use std::ops::{Deref, DerefMut};
use super::UniqueIDType;
impl<T> Deref for SslStream<T> {
type Target = InnerSsl<T>;
@ -162,7 +164,7 @@ impl<T> UniqueID for SslStream<T>
where
T: UniqueID,
{
fn id(&self) -> i32 {
fn id(&self) -> UniqueIDType {
self.ssl.get_ref().id()
}
}

View file

@ -0,0 +1,115 @@
//! Windows specific functionality for calling the WinSock c api
//!
//! Implementations here are based on the implementation in the std library
//! https://github.com/rust-lang/rust/blob/84ac80f/library/std/src/sys_common/net.rs
//! https://github.com/rust-lang/rust/blob/84ac80f/library/std/src/sys/pal/windows/net.rs
use std::os::windows::io::RawSocket;
use std::{io, mem, net::SocketAddr};
use windows_sys::Win32::Networking::WinSock::{
getpeername, getsockname, AF_INET, AF_INET6, SOCKADDR_IN, SOCKADDR_IN6, SOCKADDR_STORAGE,
SOCKET,
};
pub(crate) fn peer_addr(raw_sock: RawSocket) -> io::Result<SocketAddr> {
let mut storage = unsafe { mem::zeroed::<SOCKADDR_STORAGE>() };
let mut addrlen = mem::size_of_val(&storage) as i32;
unsafe {
let res = getpeername(
raw_sock as SOCKET,
core::ptr::addr_of_mut!(storage) as *mut _,
&mut addrlen,
);
if res != 0 {
return Err(io::Error::last_os_error());
}
}
sockaddr_to_addr(&storage, addrlen as usize)
}
pub(crate) fn local_addr(raw_sock: RawSocket) -> io::Result<SocketAddr> {
let mut storage = unsafe { mem::zeroed::<SOCKADDR_STORAGE>() };
let mut addrlen = mem::size_of_val(&storage) as i32;
unsafe {
let res = getsockname(
raw_sock as libc::SOCKET,
core::ptr::addr_of_mut!(storage) as *mut _,
&mut addrlen,
);
if res != 0 {
return Err(io::Error::last_os_error());
}
}
sockaddr_to_addr(&storage, addrlen as usize)
}
fn sockaddr_to_addr(storage: &SOCKADDR_STORAGE, len: usize) -> io::Result<SocketAddr> {
match storage.ss_family {
AF_INET => {
assert!(len >= mem::size_of::<SOCKADDR_IN>());
Ok(SocketAddr::from(unsafe {
let sockaddr = *(storage as *const _ as *const SOCKADDR_IN);
(
sockaddr.sin_addr.S_un.S_addr.to_ne_bytes(),
sockaddr.sin_port.to_be(),
)
}))
}
AF_INET6 => {
assert!(len >= mem::size_of::<SOCKADDR_IN6>());
Ok(SocketAddr::from(unsafe {
let sockaddr = *(storage as *const _ as *const SOCKADDR_IN6);
(sockaddr.sin6_addr.u.Byte, sockaddr.sin6_port.to_be())
}))
}
_ => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid argument",
)),
}
}
#[cfg(test)]
mod tests {
use std::os::windows::io::AsRawSocket;
use crate::protocols::l4::{listener::Listener, stream::Stream};
use super::*;
async fn assert_listener_and_stream(addr: &str) {
let tokio_listener = tokio::net::TcpListener::bind(addr).await.unwrap();
let listener_local_addr = tokio_listener.local_addr().unwrap();
let tokio_stream = tokio::net::TcpStream::connect(listener_local_addr)
.await
.unwrap();
let stream_local_addr = tokio_stream.local_addr().unwrap();
let stream_peer_addr = tokio_stream.peer_addr().unwrap();
let stream: Stream = tokio_stream.into();
let listener: Listener = tokio_listener.into();
let raw_sock = listener.as_raw_socket();
assert_eq!(listener_local_addr, local_addr(raw_sock).unwrap());
let raw_sock = stream.as_raw_socket();
assert_eq!(stream_peer_addr, peer_addr(raw_sock).unwrap());
assert_eq!(stream_local_addr, local_addr(raw_sock).unwrap());
}
#[tokio::test]
async fn get_v4_addrs_from_raw_socket() {
assert_listener_and_stream("127.0.0.1:0").await
}
#[tokio::test]
async fn get_v6_addrs_from_raw_socket() {
assert_listener_and_stream("[::1]:0").await
}
}

View file

@ -54,6 +54,7 @@ unsafe fn gid_for_username(name: &CString) -> Option<libc::gid_t> {
}
/// Start a server instance as a daemon.
#[cfg(unix)]
pub fn daemonize(conf: &ServerConf) {
// TODO: customize working dir

View file

@ -15,21 +15,26 @@
//! Server process and configuration management
pub mod configuration;
#[cfg(unix)]
mod daemon;
#[cfg(unix)]
pub(crate) mod transfer_fd;
#[cfg(unix)]
use daemon::daemonize;
use log::{debug, error, info, warn};
use pingora_runtime::Runtime;
use pingora_timeout::fast_timeout;
use std::sync::Arc;
use std::thread;
#[cfg(unix)]
use tokio::signal::unix;
use tokio::sync::{watch, Mutex};
use tokio::time::{sleep, Duration};
use crate::services::Service;
use configuration::{Opt, ServerConf};
#[cfg(unix)]
pub use transfer_fd::Fds;
use pingora_error::{Error, ErrorType, Result};
@ -49,6 +54,7 @@ enum ShutdownType {
/// The receiver for server's shutdown event. The value will turn to true once the server starts
/// to shutdown
pub type ShutdownWatch = watch::Receiver<bool>;
#[cfg(unix)]
pub type ListenFds = Arc<Mutex<Fds>>;
/// The server object
@ -58,6 +64,7 @@ pub type ListenFds = Arc<Mutex<Fds>>;
/// zero downtime upgrade and error reporting.
pub struct Server {
services: Vec<Box<dyn Service>>,
#[cfg(unix)]
listen_fds: Option<ListenFds>,
shutdown_watch: watch::Sender<bool>,
// TODO: we many want to drop this copy to let sender call closed()
@ -75,6 +82,7 @@ pub struct Server {
// TODO: delete the pid when exit
impl Server {
#[cfg(unix)]
async fn main_loop(&self) -> ShutdownType {
// waiting for exit signal
// TODO: there should be a signal handling function
@ -142,7 +150,7 @@ impl Server {
fn run_service(
mut service: Box<dyn Service>,
fds: Option<ListenFds>,
#[cfg(unix)] fds: Option<ListenFds>,
shutdown: ShutdownWatch,
threads: usize,
work_stealing: bool,
@ -152,12 +160,19 @@ impl Server {
{
let service_runtime = Server::create_runtime(service.name(), threads, work_stealing);
service_runtime.get_handle().spawn(async move {
service.start_service(fds, shutdown).await;
service
.start_service(
#[cfg(unix)]
fds,
shutdown,
)
.await;
info!("service exited.")
});
service_runtime
}
#[cfg(unix)]
fn load_fds(&mut self, upgrade: bool) -> Result<(), nix::Error> {
let mut fds = Fds::new();
if upgrade {
@ -185,6 +200,7 @@ impl Server {
Server {
services: vec![],
#[cfg(unix)]
listen_fds: None,
shutdown_watch: tx,
shutdown_recv: rx,
@ -225,6 +241,7 @@ impl Server {
Ok(Server {
services: vec![],
#[cfg(unix)]
listen_fds: None,
shutdown_watch: tx,
shutdown_recv: rx,
@ -267,6 +284,7 @@ impl Server {
}
// load fds
#[cfg(unix)]
match self.load_fds(self.options.as_ref().map_or(false, |o| o.upgrade)) {
Ok(_) => {
info!("Bootstrap done");
@ -294,12 +312,17 @@ impl Server {
let conf = self.configuration.as_ref();
#[cfg(unix)]
if conf.daemon {
info!("Daemonizing the server");
fast_timeout::pause_for_fork();
daemonize(&self.configuration);
fast_timeout::unpause();
}
#[cfg(windows)]
if conf.daemon {
panic!("Daemonizing under windows is not supported");
}
/* only init sentry in release builds */
#[cfg(not(debug_assertions))]
@ -314,6 +337,7 @@ impl Server {
let threads = service.threads().unwrap_or(conf.threads);
let runtime = Server::run_service(
service,
#[cfg(unix)]
self.listen_fds.clone(),
self.shutdown_recv.clone(),
threads,
@ -325,7 +349,10 @@ impl Server {
// blocked on main loop so that it runs forever
// Only work steal runtime can use block_on()
let server_runtime = Server::create_runtime("Server", 1, true);
#[cfg(unix)]
let shutdown_type = server_runtime.get_handle().block_on(self.main_loop());
#[cfg(windows)]
let shutdown_type = ShutdownType::Graceful;
if matches!(shutdown_type, ShutdownType::Graceful) {
let exit_timeout = self

View file

@ -23,7 +23,9 @@ use async_trait::async_trait;
use std::sync::Arc;
use super::Service;
use crate::server::{ListenFds, ShutdownWatch};
#[cfg(unix)]
use crate::server::ListenFds;
use crate::server::ShutdownWatch;
/// The background service interface
#[async_trait]
@ -65,7 +67,11 @@ impl<A> Service for GenBackgroundService<A>
where
A: BackgroundService + Send + Sync + 'static,
{
async fn start_service(&mut self, _fds: Option<ListenFds>, shutdown: ShutdownWatch) {
async fn start_service(
&mut self,
#[cfg(unix)] _fds: Option<ListenFds>,
shutdown: ShutdownWatch,
) {
self.task.start(shutdown).await;
}

View file

@ -21,7 +21,9 @@
use crate::apps::ServerApp;
use crate::listeners::{Listeners, ServerAddress, TcpSocketOptions, TlsSettings, TransportStack};
use crate::protocols::Stream;
use crate::server::{ListenFds, ShutdownWatch};
#[cfg(unix)]
use crate::server::ListenFds;
use crate::server::ShutdownWatch;
use crate::services::Service as ServiceTrait;
use async_trait::async_trait;
@ -83,6 +85,7 @@ impl<A> Service<A> {
///
/// Optionally take a permission of the socket file. The default is read and write access for
/// everyone (0o666).
#[cfg(unix)]
pub fn add_uds(&mut self, addr: &str, perm: Option<Permissions>) {
self.listeners.add_uds(addr, perm);
}
@ -201,9 +204,16 @@ impl<A: ServerApp + Send + Sync + 'static> Service<A> {
#[async_trait]
impl<A: ServerApp + Send + Sync + 'static> ServiceTrait for Service<A> {
async fn start_service(&mut self, fds: Option<ListenFds>, shutdown: ShutdownWatch) {
async fn start_service(
&mut self,
#[cfg(unix)] fds: Option<ListenFds>,
shutdown: ShutdownWatch,
) {
let runtime = current_handle();
let endpoints = self.listeners.build(fds);
let endpoints = self.listeners.build(
#[cfg(unix)]
fds,
);
let app_logic = self
.app_logic
.take()

View file

@ -23,7 +23,9 @@
use async_trait::async_trait;
use crate::server::{ListenFds, ShutdownWatch};
#[cfg(unix)]
use crate::server::ListenFds;
use crate::server::ShutdownWatch;
pub mod background;
pub mod listening;
@ -39,7 +41,11 @@ pub trait Service: Sync + Send {
/// the collection, the service should create its own listening sockets and then put them into
/// the collection in order for them to be passed to the next server.
/// - `shutdown`: the shutdown signal this server would receive.
async fn start_service(&mut self, fds: Option<ListenFds>, mut shutdown: ShutdownWatch);
async fn start_service(
&mut self,
#[cfg(unix)] fds: Option<ListenFds>,
mut shutdown: ShutdownWatch,
);
/// The name of the service, just for logging and naming the threads assigned to this service
///

View file

@ -23,14 +23,17 @@ use std::collections::BTreeMap;
use std::fmt::{Display, Formatter, Result as FmtResult};
use std::hash::{Hash, Hasher};
use std::net::{IpAddr, SocketAddr as InetSocketAddr, ToSocketAddrs as ToInetSocketAddrs};
use std::os::unix::net::SocketAddr as UnixSocketAddr;
use std::os::unix::prelude::AsRawFd;
#[cfg(unix)]
use std::os::unix::{net::SocketAddr as UnixSocketAddr, prelude::AsRawFd};
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use crate::connectors::L4Connect;
use crate::protocols::l4::socket::SocketAddr;
#[cfg(unix)]
use crate::protocols::ConnFdReusable;
use crate::protocols::TcpKeepalive;
use crate::tls::x509::X509;
@ -186,9 +189,15 @@ pub trait Peer: Display + Clone {
.unwrap_or_default()
}
#[cfg(unix)]
fn matches_fd<V: AsRawFd>(&self, fd: V) -> bool {
self.address().check_fd_match(fd)
}
#[cfg(windows)]
fn matches_sock<V: AsRawSocket>(&self, sock: V) -> bool {
use crate::protocols::ConnSockReusable;
self.address().check_sock_match(sock)
}
fn get_tracer(&self) -> Option<Tracer> {
None
@ -211,6 +220,7 @@ impl BasicPeer {
}
/// Create a new [`BasicPeer`] with the given path to a Unix domain socket.
#[cfg(unix)]
pub fn new_uds<P: AsRef<Path>>(path: P) -> Result<Self> {
let addr = SocketAddr::Unix(
UnixSocketAddr::from_pathname(path.as_ref())
@ -450,6 +460,7 @@ impl HttpPeer {
}
/// Create a new [`HttpPeer`] with the given path to Unix domain socket and TLS settings.
#[cfg(unix)]
pub fn new_uds(path: &str, tls: bool, sni: String) -> Result<Self> {
let addr = SocketAddr::Unix(
UnixSocketAddr::from_pathname(Path::new(path)).or_err(SocketError, "invalid path")?,
@ -552,6 +563,7 @@ impl Peer for HttpPeer {
self.proxy.as_ref()
}
#[cfg(unix)]
fn matches_fd<V: AsRawFd>(&self, fd: V) -> bool {
if let Some(proxy) = self.get_proxy() {
proxy.next_hop.check_fd_match(fd)
@ -559,6 +571,16 @@ impl Peer for HttpPeer {
self.address().check_fd_match(fd)
}
}
#[cfg(windows)]
fn matches_sock<V: AsRawSocket>(&self, sock: V) -> bool {
use crate::protocols::ConnSockReusable;
if let Some(proxy) = self.get_proxy() {
panic!("windows do not support peers with proxy")
} else {
self.address().check_sock_match(sock)
}
}
fn get_client_cert_key(&self) -> Option<&Arc<CertKey>> {
self.client_cert_key.as_ref()

View file

@ -15,6 +15,7 @@
mod utils;
use hyper::Client;
#[cfg(unix)]
use hyperlocal::{UnixClientExt, Uri};
use utils::init;
@ -49,6 +50,7 @@ async fn test_https_http2() {
assert_eq!(res.version(), reqwest::Version::HTTP_11);
}
#[cfg(unix)]
#[tokio::test]
async fn test_uds() {
init();

View file

@ -78,6 +78,7 @@ fn entry_point(opt: Option<Opt>) {
my_server.bootstrap();
let mut listeners = Listeners::tcp("0.0.0.0:6145");
#[cfg(unix)]
listeners.add_uds("/tmp/echo.sock", None);
let mut tls_settings =

View file

@ -27,7 +27,10 @@ use tokio::sync::{oneshot, watch, Notify, OwnedMutexGuard};
use super::lru::Lru;
type GroupKey = u64;
#[cfg(unix)]
type ID = i32;
#[cfg(windows)]
type ID = usize;
/// the metadata of a connection
#[derive(Clone, Debug)]

View file

@ -34,6 +34,9 @@ once_cell = { workspace = true }
clap = { version = "3.2.25", features = ["derive"] }
regex = "1"
[target.'cfg(unix)'.dev-dependencies]
hyperlocal = "0.8"
[dev-dependencies]
reqwest = { version = "0.11", features = [
"gzip",
@ -41,7 +44,6 @@ reqwest = { version = "0.11", features = [
], default-features = false }
tokio-test = "0.4"
env_logger = "0.9"
hyperlocal = "0.8"
hyper = "0.14"
tokio-tungstenite = "0.20.1"
pingora-load-balancing = { version = "0.3.0", path = "../pingora-load-balancing" }
@ -61,4 +63,4 @@ boringssl = ["pingora-core/boringssl", "pingora-cache/boringssl"]
rustdoc-args = ["--cfg", "doc_async_trait"]
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(doc_async_trait)'] }
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(doc_async_trait)'] }

View file

@ -116,13 +116,18 @@ impl<SV> HttpProxy<SV> {
SV: ProxyHttp + Send + Sync,
SV::CTX: Send + Sync,
{
#[cfg(windows)]
let raw = client_session.id() as std::os::windows::io::RawSocket;
#[cfg(unix)]
let raw = client_session.id();
if let Err(e) = self
.inner
.connected_to_upstream(
session,
reused,
peer,
client_session.id(),
raw,
Some(client_session.digest()),
ctx,
)

View file

@ -192,16 +192,14 @@ impl<SV> HttpProxy<SV> {
SV: ProxyHttp + Send + Sync,
SV::CTX: Send + Sync,
{
#[cfg(windows)]
let raw = client_session.fd() as std::os::windows::io::RawSocket;
#[cfg(unix)]
let raw = client_session.fd();
if let Err(e) = self
.inner
.connected_to_upstream(
session,
reused,
peer,
client_session.fd(),
client_session.digest(),
ctx,
)
.connected_to_upstream(session, reused, peer, raw, client_session.digest(), ctx)
.await
{
return (false, Some(e));

View file

@ -425,7 +425,8 @@ pub trait ProxyHttp {
_session: &mut Session,
_reused: bool,
_peer: &HttpPeer,
_fd: std::os::unix::io::RawFd,
#[cfg(unix)] _fd: std::os::unix::io::RawFd,
#[cfg(windows)] _sock: std::os::windows::io::RawSocket,
_digest: Option<&Digest>,
_ctx: &mut Self::CTX,
) -> Result<()>

View file

@ -19,6 +19,7 @@ use pingora_cache::lock::WritePermit;
use pingora_core::protocols::raw_connect::ProxyDigest;
use pingora_core::protocols::{
GetProxyDigest, GetSocketDigest, GetTimingDigest, SocketDigest, Ssl, TimingDigest, UniqueID,
UniqueIDType,
};
use std::io::Cursor;
use std::sync::Arc;
@ -68,7 +69,7 @@ impl AsyncWrite for DummyIO {
}
impl UniqueID for DummyIO {
fn id(&self) -> i32 {
fn id(&self) -> UniqueIDType {
0 // placeholder
}
}

View file

@ -15,6 +15,7 @@
mod utils;
use hyper::{body::HttpBody, header::HeaderValue, Body, Client};
#[cfg(unix)]
use hyperlocal::{UnixClientExt, Uri};
use reqwest::{header, StatusCode};
@ -233,6 +234,7 @@ async fn test_h2_to_h1_upload() {
assert_eq!(body, payload);
}
#[cfg(unix)]
#[tokio::test]
async fn test_simple_proxy_uds() {
init();
@ -262,6 +264,7 @@ async fn test_simple_proxy_uds() {
assert_eq!(body.as_ref(), b"Hello World!\n");
}
#[cfg(unix)]
#[tokio::test]
async fn test_simple_proxy_uds_peer() {
init();

View file

@ -192,7 +192,8 @@ impl ProxyHttp for ExampleProxyHttps {
_http_session: &mut Session,
reused: bool,
_peer: &HttpPeer,
_fd: std::os::unix::io::RawFd,
#[cfg(unix)] _fd: std::os::unix::io::RawFd,
#[cfg(windows)] _sock: std::os::windows::io::RawSocket,
digest: Option<&Digest>,
ctx: &mut CTX,
) -> Result<()> {
@ -279,6 +280,7 @@ impl ProxyHttp for ExampleProxyHttp {
_ctx: &mut Self::CTX,
) -> Result<Box<HttpPeer>> {
let req = session.req_header();
#[cfg(unix)]
if req.headers.contains_key("x-uds-peer") {
return Ok(Box::new(HttpPeer::new_uds(
"/tmp/nginx-test.sock",
@ -310,7 +312,8 @@ impl ProxyHttp for ExampleProxyHttp {
_http_session: &mut Session,
reused: bool,
_peer: &HttpPeer,
_fd: std::os::unix::io::RawFd,
#[cfg(unix)] _fd: std::os::unix::io::RawFd,
#[cfg(windows)] _sock: std::os::windows::io::RawSocket,
digest: Option<&Digest>,
ctx: &mut CTX,
) -> Result<()> {
@ -527,6 +530,7 @@ fn test_main() {
let mut proxy_service_http =
pingora_proxy::http_proxy_service(&my_server.configuration, ExampleProxyHttp {});
proxy_service_http.add_tcp("0.0.0.0:6147");
#[cfg(unix)]
proxy_service_http.add_uds("/tmp/pingora_proxy.sock", None);
let mut proxy_service_h2c =

View file

@ -29,15 +29,17 @@ pingora-load-balancing = { version = "0.3.0", path = "../pingora-load-balancing"
pingora-proxy = { version = "0.3.0", path = "../pingora-proxy", optional = true, default-features = false }
pingora-cache = { version = "0.3.0", path = "../pingora-cache", optional = true, default-features = false }
[target.'cfg(unix)'.dev-dependencies]
hyperlocal = "0.8"
jemallocator = "0.5"
[dev-dependencies]
clap = { version = "3.2.25", features = ["derive"] }
tokio = { workspace = true, features = ["rt-multi-thread", "signal"] }
matches = "0.1"
env_logger = "0.9"
reqwest = { version = "0.11", features = ["rustls"], default-features = false }
hyperlocal = "0.8"
hyper = "0.14"
jemallocator = "0.5"
async-trait = { workspace = true }
http = { workspace = true }
log = { workspace = true }