diff --git a/.bleep b/.bleep index 8439eac..b99e531 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -837db6c7ec2d37abf83f9588be99fda00e2012c3 \ No newline at end of file +c90e4ce2596840c60b5ff1737e2141447e5953e1 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 477b010..81bc51a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -46,7 +46,7 @@ jobs: - name: Run cargo clippy run: | - [[ ${{ matrix.toolchain }} == nightly ]] || cargo clippy --all-targets --all -- --deny=warnings + [[ ${{ matrix.toolchain }} == nightly ]] || cargo clippy --all-targets --all -- --allow=unknown-lints --deny=warnings - name: Run cargo audit run: | diff --git a/pingora-cache/src/eviction/lru.rs b/pingora-cache/src/eviction/lru.rs index 91bc8fc..fa09fab 100644 --- a/pingora-cache/src/eviction/lru.rs +++ b/pingora-cache/src/eviction/lru.rs @@ -32,6 +32,7 @@ use std::time::SystemTime; /// /// - Space optimized in-memory LRU (see [pingora_lru]). /// - Instead of a single giant LRU, this struct shards the assets into `N` independent LRUs. +/// /// This allows [EvictionManager::save()] not to lock the entire cache manager while performing /// serialization. pub struct Manager(Lru); diff --git a/pingora-cache/src/lib.rs b/pingora-cache/src/lib.rs index 6ba057e..86cfbdf 100644 --- a/pingora-cache/src/lib.rs +++ b/pingora-cache/src/lib.rs @@ -344,10 +344,10 @@ impl HttpCache { /// - `storage`: the cache storage backend that implements [storage::Storage] /// - `eviction`: optionally the eviction manager, without it, nothing will be evicted from the storage /// - `predictor`: optionally a cache predictor. The cache predictor predicts whether something is likely - /// to be cacheable or not. This is useful because the proxy can apply different types of optimization to - /// cacheable and uncacheable requests. + /// to be cacheable or not. This is useful because the proxy can apply different types of optimization to + /// cacheable and uncacheable requests. /// - `cache_lock`: optionally a cache lock which handles concurrent lookups to the same asset. Without it - /// such lookups will all be allowed to fetch the asset independently. + /// such lookups will all be allowed to fetch the asset independently. pub fn enable( &mut self, storage: &'static (dyn storage::Storage + Sync), diff --git a/pingora-core/Cargo.toml b/pingora-core/Cargo.toml index c13823c..c5bce77 100644 --- a/pingora-core/Cargo.toml +++ b/pingora-core/Cargo.toml @@ -81,4 +81,4 @@ jemallocator = "0.5" default = ["openssl"] openssl = ["pingora-openssl"] boringssl = ["pingora-boringssl"] -patched_http1 = [] +patched_http1 = [] \ No newline at end of file diff --git a/pingora-core/src/apps/http_app.rs b/pingora-core/src/apps/http_app.rs index 6983e8e..91ca58a 100644 --- a/pingora-core/src/apps/http_app.rs +++ b/pingora-core/src/apps/http_app.rs @@ -28,7 +28,7 @@ use crate::protocols::Stream; use crate::server::ShutdownWatch; /// This trait defines how to map a request to a response -#[cfg_attr(not(doc_async_trait), async_trait)] +#[async_trait] pub trait ServeHttp { /// Define the mapping from a request to a response. /// Note that the request header is already read, but the implementation needs to read the @@ -42,7 +42,7 @@ pub trait ServeHttp { } // TODO: remove this in favor of HttpServer? -#[cfg_attr(not(doc_async_trait), async_trait)] +#[async_trait] impl HttpServerApp for SV where SV: ServeHttp + Send + Sync, @@ -128,7 +128,7 @@ impl HttpServer { } } -#[cfg_attr(not(doc_async_trait), async_trait)] +#[async_trait] impl HttpServerApp for HttpServer where SV: ServeHttp + Send + Sync, diff --git a/pingora-core/src/apps/mod.rs b/pingora-core/src/apps/mod.rs index 758eb37..32fd82f 100644 --- a/pingora-core/src/apps/mod.rs +++ b/pingora-core/src/apps/mod.rs @@ -28,7 +28,7 @@ use crate::protocols::Digest; use crate::protocols::Stream; use crate::protocols::ALPN; -#[cfg_attr(not(doc_async_trait), async_trait)] +#[async_trait] /// This trait defines the interface of a transport layer (TCP or TLS) application. pub trait ServerApp { /// Whenever a new connection is established, this function will be called with the established @@ -62,7 +62,7 @@ pub struct HttpServerOptions { } /// This trait defines the interface of an HTTP application. -#[cfg_attr(not(doc_async_trait), async_trait)] +#[async_trait] pub trait HttpServerApp { /// Similar to the [`ServerApp`], this function is called whenever a new HTTP session is established. /// @@ -95,7 +95,7 @@ pub trait HttpServerApp { async fn http_cleanup(&self) {} } -#[cfg_attr(not(doc_async_trait), async_trait)] +#[async_trait] impl ServerApp for T where T: HttpServerApp + Send + Sync + 'static, diff --git a/pingora-core/src/apps/prometheus_http_app.rs b/pingora-core/src/apps/prometheus_http_app.rs index 38072bb..128ba6c 100644 --- a/pingora-core/src/apps/prometheus_http_app.rs +++ b/pingora-core/src/apps/prometheus_http_app.rs @@ -29,7 +29,7 @@ use crate::protocols::http::ServerSession; /// collected via the [Prometheus](https://docs.rs/prometheus/) crate; pub struct PrometheusHttpApp; -#[cfg_attr(not(doc_async_trait), async_trait)] +#[async_trait] impl ServeHttp for PrometheusHttpApp { async fn response(&self, _http_session: &mut ServerSession) -> Response> { let encoder = TextEncoder::new(); diff --git a/pingora-core/src/connectors/l4.rs b/pingora-core/src/connectors/l4.rs index 449ea4c..d226a8b 100644 --- a/pingora-core/src/connectors/l4.rs +++ b/pingora-core/src/connectors/l4.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_trait::async_trait; use log::debug; use pingora_error::{Context, Error, ErrorType::*, OrErr, Result}; use rand::seq::SliceRandom; @@ -26,6 +27,12 @@ use crate::protocols::l4::stream::Stream; use crate::protocols::{GetSocketDigest, SocketDigest}; use crate::upstreams::peer::Peer; +/// The interface to establish a L4 connection +#[async_trait] +pub trait Connect: std::fmt::Debug { + async fn connect(&self, addr: &SocketAddr) -> Result; +} + /// Establish a connection (l4) to the given peer using its settings and an optional bind address. pub async fn connect

(peer: &P, bind_to: Option) -> Result where @@ -37,72 +44,78 @@ where .err_context(|| format!("Fail to establish CONNECT proxy: {}", peer)); } let peer_addr = peer.address(); - let mut stream: Stream = match peer_addr { - SocketAddr::Inet(addr) => { - let connect_future = tcp_connect(addr, bind_to.as_ref(), |socket| { - if peer.tcp_fast_open() { - set_tcp_fastopen_connect(socket.as_raw_fd())?; - } - if let Some(recv_buf) = peer.tcp_recv_buf() { - debug!("Setting recv buf size"); - set_recv_buf(socket.as_raw_fd(), recv_buf)?; - } - if let Some(dscp) = peer.dscp() { - debug!("Setting dscp"); - set_dscp(socket.as_raw_fd(), dscp)?; - } - Ok(()) - }); - let conn_res = match peer.connection_timeout() { - Some(t) => pingora_timeout::timeout(t, connect_future) - .await - .explain_err(ConnectTimedout, |_| { - format!("timeout {t:?} connecting to server {peer}") - })?, - None => connect_future.await, - }; - match conn_res { - Ok(socket) => { - debug!("connected to new server: {}", peer.address()); - Ok(socket.into()) - } - Err(e) => { - let c = format!("Fail to connect to {peer}"); - match e.etype() { - SocketError | BindError => Error::e_because(InternalError, c, e), - _ => Err(e.more_context(c)), + let mut stream: Stream = + if let Some(custom_l4) = peer.get_peer_options().and_then(|o| o.custom_l4.as_ref()) { + custom_l4.connect(peer_addr).await? + } else { + match peer_addr { + SocketAddr::Inet(addr) => { + let connect_future = tcp_connect(addr, bind_to.as_ref(), |socket| { + if peer.tcp_fast_open() { + set_tcp_fastopen_connect(socket.as_raw_fd())?; + } + if let Some(recv_buf) = peer.tcp_recv_buf() { + debug!("Setting recv buf size"); + set_recv_buf(socket.as_raw_fd(), recv_buf)?; + } + if let Some(dscp) = peer.dscp() { + debug!("Setting dscp"); + set_dscp(socket.as_raw_fd(), dscp)?; + } + Ok(()) + }); + let conn_res = match peer.connection_timeout() { + Some(t) => pingora_timeout::timeout(t, connect_future) + .await + .explain_err(ConnectTimedout, |_| { + format!("timeout {t:?} connecting to server {peer}") + })?, + None => connect_future.await, + }; + match conn_res { + Ok(socket) => { + debug!("connected to new server: {}", peer.address()); + Ok(socket.into()) + } + Err(e) => { + let c = format!("Fail to connect to {peer}"); + match e.etype() { + SocketError | BindError => Error::e_because(InternalError, c, e), + _ => Err(e.more_context(c)), + } + } } } - } - } - SocketAddr::Unix(addr) => { - let connect_future = connect_uds( - addr.as_pathname() - .expect("non-pathname unix sockets not supported as peer"), - ); - let conn_res = match peer.connection_timeout() { - Some(t) => pingora_timeout::timeout(t, connect_future) - .await - .explain_err(ConnectTimedout, |_| { - format!("timeout {t:?} connecting to server {peer}") - })?, - None => connect_future.await, - }; - match conn_res { - Ok(socket) => { - debug!("connected to new server: {}", peer.address()); - Ok(socket.into()) - } - Err(e) => { - let c = format!("Fail to connect to {peer}"); - match e.etype() { - SocketError | BindError => Error::e_because(InternalError, c, e), - _ => Err(e.more_context(c)), + SocketAddr::Unix(addr) => { + let connect_future = connect_uds( + addr.as_pathname() + .expect("non-pathname unix sockets not supported as peer"), + ); + let conn_res = match peer.connection_timeout() { + Some(t) => pingora_timeout::timeout(t, connect_future) + .await + .explain_err(ConnectTimedout, |_| { + format!("timeout {t:?} connecting to server {peer}") + })?, + None => connect_future.await, + }; + match conn_res { + Ok(socket) => { + debug!("connected to new server: {}", peer.address()); + Ok(socket.into()) + } + Err(e) => { + let c = format!("Fail to connect to {peer}"); + match e.etype() { + SocketError | BindError => Error::e_because(InternalError, c, e), + _ => Err(e.more_context(c)), + } + } } } - } - } - }?; + }? + }; + let tracer = peer.get_tracer(); if let Some(t) = tracer { t.0.on_connected(); @@ -249,6 +262,29 @@ mod tests { assert_eq!(new_session.unwrap_err().etype(), &ConnectTimedout) } + #[tokio::test] + async fn test_custom_connect() { + #[derive(Debug)] + struct MyL4; + #[async_trait] + impl Connect for MyL4 { + async fn connect(&self, _addr: &SocketAddr) -> Result { + tokio::net::TcpStream::connect("1.1.1.1:80") + .await + .map(|s| s.into()) + .or_fail() + } + } + // :79 shouldn't be able to be connected to + let mut peer = BasicPeer::new("1.1.1.1:79"); + peer.options.custom_l4 = Some(std::sync::Arc::new(MyL4 {})); + + let new_session = connect(&peer, None).await; + + // but MyL4 connects to :80 instead + assert!(new_session.is_ok()); + } + #[tokio::test] async fn test_connect_proxy_fail() { let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string()); diff --git a/pingora-core/src/connectors/mod.rs b/pingora-core/src/connectors/mod.rs index 9bfd91b..d13f3a9 100644 --- a/pingora-core/src/connectors/mod.rs +++ b/pingora-core/src/connectors/mod.rs @@ -25,6 +25,7 @@ use crate::tls::ssl::SslConnector; use crate::upstreams::peer::{Peer, ALPN}; use l4::connect as l4_connect; +pub use l4::Connect as L4Connect; use log::{debug, error, warn}; use offload::OffloadRuntime; use parking_lot::RwLock; diff --git a/pingora-core/src/lib.rs b/pingora-core/src/lib.rs index 3435fe5..4c3e29a 100644 --- a/pingora-core/src/lib.rs +++ b/pingora-core/src/lib.rs @@ -18,8 +18,6 @@ #![allow(clippy::match_wild_err_arm)] #![allow(clippy::missing_safety_doc)] #![allow(clippy::upper_case_acronyms)] -// enable nightly feature async trait so that the docs are cleaner -#![cfg_attr(doc_async_trait, feature(async_fn_in_trait))] //! # Pingora //! diff --git a/pingora-core/src/protocols/http/server.rs b/pingora-core/src/protocols/http/server.rs index c6479e7..74aada6 100644 --- a/pingora-core/src/protocols/http/server.rs +++ b/pingora-core/src/protocols/http/server.rs @@ -53,7 +53,7 @@ impl Session { /// else with the session. /// - `Ok(true)`: successful /// - `Ok(false)`: client exit without sending any bytes. This is normal on reused connection. - /// In this case the user should give up this session. + /// In this case the user should give up this session. pub async fn read_request(&mut self) -> Result { match self { Self::H1(s) => { @@ -218,6 +218,19 @@ impl Session { } } + /// Sets whether we ignore writing informational responses downstream. + /// + /// For HTTP/1.1 this is a noop if the response is Upgrade or Continue and + /// Expect: 100-continue was set on the request. + /// + /// This is a noop for h2 because informational responses are always ignored. + pub fn set_ignore_info_resp(&mut self, ignore: bool) { + match self { + Self::H1(s) => s.set_ignore_info_resp(ignore), + Self::H2(_) => {} // always ignored + } + } + /// Return a digest of the request including the method, path and Host header // TODO: make this use a `Formatter` pub fn request_summary(&self) -> String { diff --git a/pingora-core/src/protocols/http/v1/common.rs b/pingora-core/src/protocols/http/v1/common.rs index 18fa6c0..d6ade98 100644 --- a/pingora-core/src/protocols/http/v1/common.rs +++ b/pingora-core/src/protocols/http/v1/common.rs @@ -153,6 +153,14 @@ pub(super) fn is_upgrade_req(req: &RequestHeader) -> bool { req.version == http::Version::HTTP_11 && req.headers.get(header::UPGRADE).is_some() } +pub(super) fn is_expect_continue_req(req: &RequestHeader) -> bool { + req.version == http::Version::HTTP_11 + // https://www.rfc-editor.org/rfc/rfc9110#section-10.1.1 + && req.headers.get(header::EXPECT).map_or(false, |v| { + v.as_bytes().eq_ignore_ascii_case(b"100-continue") + }) +} + // Unlike the upgrade check on request, this function doesn't check the Upgrade or Connection header // because when seeing 101, we assume the server accepts to switch protocol. // In reality it is not common that some servers don't send all the required headers to establish diff --git a/pingora-core/src/protocols/http/v1/server.rs b/pingora-core/src/protocols/http/v1/server.rs index 82a93a2..ea99d74 100644 --- a/pingora-core/src/protocols/http/v1/server.rs +++ b/pingora-core/src/protocols/http/v1/server.rs @@ -74,6 +74,8 @@ pub struct HttpSession { digest: Box, /// Minimum send rate to the client min_send_rate: Option, + /// When this is enabled informational response headers will not be proxied downstream + ignore_info_resp: bool, } impl HttpSession { @@ -109,6 +111,7 @@ impl HttpSession { upgraded: false, digest, min_send_rate: None, + ignore_info_resp: false, } } @@ -388,6 +391,11 @@ impl HttpSession { /// Write the response header to the client. /// This function can be called more than once to send 1xx informational headers excluding 101. pub async fn write_response_header(&mut self, mut header: Box) -> Result<()> { + if header.status.is_informational() && self.ignore_info_resp(header.status.into()) { + debug!("ignoring informational headers"); + return Ok(()); + } + if let Some(resp) = self.response_written.as_ref() { if !resp.status.is_informational() || self.upgraded { warn!("Respond header is already sent, cannot send again"); @@ -409,7 +417,7 @@ impl HttpSession { header.insert_header(header::CONNECTION, connection_value)?; } - if header.status.as_u16() == 101 { + if header.status == 101 { // make sure the connection is closed at the end when 101/upgrade is used self.set_keepalive(None); } @@ -510,6 +518,18 @@ impl HttpSession { (None, None) } + fn ignore_info_resp(&self, status: u16) -> bool { + // ignore informational response if ignore flag is set and it's not an Upgrade and Expect: 100-continue isn't set + self.ignore_info_resp && status != 101 && !(status == 100 && self.is_expect_continue_req()) + } + + fn is_expect_continue_req(&self) -> bool { + match self.request_header.as_deref() { + Some(req) => is_expect_continue_req(req), + None => false, + } + } + fn is_connection_keepalive(&self) -> Option { is_buf_keepalive(self.get_header(header::CONNECTION)) } @@ -824,6 +844,14 @@ impl HttpSession { } } + /// Sets whether we ignore writing informational responses downstream. + /// + /// This is a noop if the response is Upgrade or Continue and + /// Expect: 100-continue was set on the request. + pub fn set_ignore_info_resp(&mut self, ignore: bool) { + self.ignore_info_resp = ignore; + } + /// Return the [Digest] of the connection. pub fn digest(&self) -> &Digest { &self.digest @@ -1472,6 +1500,75 @@ mod tests_stream { .unwrap(); } + #[tokio::test] + async fn write_informational_ignored() { + let wire = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; + let mock_io = Builder::new().write(wire).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + // ignore the 100 Continue + http_stream.ignore_info_resp = true; + let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap(); + http_stream + .write_response_header_ref(&response_100) + .await + .unwrap(); + let mut response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap(); + response_200.append_header("Foo", "Bar").unwrap(); + http_stream.update_resp_headers = false; + http_stream + .write_response_header_ref(&response_200) + .await + .unwrap(); + } + + #[tokio::test] + async fn write_informational_100_not_ignored_if_expect_continue() { + let input = b"GET / HTTP/1.1\r\nExpect: 100-continue\r\n\r\n"; + let output = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; + + let mock_io = Builder::new().read(&input[..]).write(output).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + http_stream.ignore_info_resp = true; + // 100 Continue is not ignored due to Expect: 100-continue on request + let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap(); + http_stream + .write_response_header_ref(&response_100) + .await + .unwrap(); + let mut response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap(); + response_200.append_header("Foo", "Bar").unwrap(); + http_stream.update_resp_headers = false; + http_stream + .write_response_header_ref(&response_200) + .await + .unwrap(); + } + + #[tokio::test] + async fn write_informational_1xx_ignored_if_expect_continue() { + let input = b"GET / HTTP/1.1\r\nExpect: 100-continue\r\n\r\n"; + let output = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n"; + + let mock_io = Builder::new().read(&input[..]).write(output).build(); + let mut http_stream = HttpSession::new(Box::new(mock_io)); + http_stream.read_request().await.unwrap(); + http_stream.ignore_info_resp = true; + // 102 Processing is ignored + let response_102 = ResponseHeader::build(StatusCode::PROCESSING, None).unwrap(); + http_stream + .write_response_header_ref(&response_102) + .await + .unwrap(); + let mut response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap(); + response_200.append_header("Foo", "Bar").unwrap(); + http_stream.update_resp_headers = false; + http_stream + .write_response_header_ref(&response_200) + .await + .unwrap(); + } + #[tokio::test] async fn write_101_switching_protocol() { let wire = b"HTTP/1.1 101 Switching Protocols\r\nFoo: Bar\r\n\r\n"; diff --git a/pingora-core/src/protocols/http/v2/server.rs b/pingora-core/src/protocols/http/v2/server.rs index f9072d6..5bd6bc0 100644 --- a/pingora-core/src/protocols/http/v2/server.rs +++ b/pingora-core/src/protocols/http/v2/server.rs @@ -194,22 +194,21 @@ impl HttpSession { return Ok(()); } - // FIXME: we should ignore 1xx header because send_response() can only be called once - // https://github.com/hyperium/h2/issues/167 - - if let Some(resp) = self.response_written.as_ref() { - if !resp.status.is_informational() { - warn!("Respond header is already sent, cannot send again"); - return Ok(()); - } + if header.status.is_informational() { + // ignore informational response 1xx header because send_response() can only be called once + // https://github.com/hyperium/h2/issues/167 + debug!("ignoring informational headers"); + return Ok(()); } - // no need to add these headers to 1xx responses - if !header.status.is_informational() { - /* update headers */ - header.insert_header(header::DATE, get_cached_date())?; + if self.response_written.as_ref().is_some() { + warn!("Response header is already sent, cannot send again"); + return Ok(()); } + /* update headers */ + header.insert_header(header::DATE, get_cached_date())?; + // remove other h1 hop headers that cannot be present in H2 // https://httpwg.org/specs/rfc7540.html#n-connection-specific-header-fields header.remove_header(&header::TRANSFER_ENCODING); @@ -486,7 +485,8 @@ mod test { expected_trailers.insert("test", HeaderValue::from_static("trailers")); let trailers = expected_trailers.clone(); - tokio::spawn(async move { + let mut handles = vec![]; + handles.push(tokio::spawn(async move { let (h2, connection) = h2::client::handshake(client).await.unwrap(); tokio::spawn(async move { connection.await.unwrap(); @@ -510,7 +510,7 @@ mod test { assert_eq!(data, server_body); let resp_trailers = body.trailers().await.unwrap().unwrap(); assert_eq!(resp_trailers, expected_trailers); - }); + })); let mut connection = handshake(Box::new(server), None).await.unwrap(); let digest = Arc::new(Digest::default()); @@ -520,7 +520,7 @@ mod test { .unwrap() { let trailers = trailers.clone(); - tokio::spawn(async move { + handles.push(tokio::spawn(async move { let req = http.req_header(); assert_eq!(req.method, Method::GET); assert_eq!(req.uri, "https://www.example.com/"); @@ -545,7 +545,11 @@ mod test { } let response_header = Box::new(ResponseHeader::build(200, None).unwrap()); - http.write_response_header(response_header, false).unwrap(); + assert!(http + .write_response_header(response_header.clone(), false) + .is_ok()); + // this write should be ignored otherwise we will error + assert!(http.write_response_header(response_header, false).is_ok()); // test idling after response header is sent tokio::select! { @@ -559,7 +563,11 @@ mod test { http.write_trailers(trailers).unwrap(); http.finish().unwrap(); - }); + })); + } + for handle in handles { + // ensure no panics + assert!(handle.await.is_ok()); } } } diff --git a/pingora-core/src/protocols/l4/ext.rs b/pingora-core/src/protocols/l4/ext.rs index 56af522..5123bdf 100644 --- a/pingora-core/src/protocols/l4/ext.rs +++ b/pingora-core/src/protocols/l4/ext.rs @@ -369,13 +369,10 @@ fn wrap_os_connect_error(e: std::io::Error, context: String) -> Box { Error::because(InternalError, context, e) } _ => match e.raw_os_error() { - Some(code) => match code { - libc::ENETUNREACH | libc::EHOSTUNREACH => { - Error::because(ConnectNoRoute, context, e) - } - _ => Error::because(ConnectError, context, e), - }, - None => Error::because(ConnectError, context, e), + Some(libc::ENETUNREACH | libc::EHOSTUNREACH) => { + Error::because(ConnectNoRoute, context, e) + } + _ => Error::because(ConnectError, context, e), }, } } diff --git a/pingora-core/src/services/background.rs b/pingora-core/src/services/background.rs index 4eec577..2b84532 100644 --- a/pingora-core/src/services/background.rs +++ b/pingora-core/src/services/background.rs @@ -26,7 +26,7 @@ use super::Service; use crate::server::{ListenFds, ShutdownWatch}; /// The background service interface -#[cfg_attr(not(doc_async_trait), async_trait)] +#[async_trait] pub trait BackgroundService { /// This function is called when the pingora server tries to start all the /// services. The background service can return at anytime or wait for the diff --git a/pingora-core/src/upstreams/peer.rs b/pingora-core/src/upstreams/peer.rs index d0c8125..c4a23a8 100644 --- a/pingora-core/src/upstreams/peer.rs +++ b/pingora-core/src/upstreams/peer.rs @@ -29,6 +29,7 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; +use crate::connectors::L4Connect; use crate::protocols::l4::socket::SocketAddr; use crate::protocols::ConnFdReusable; use crate::protocols::TcpKeepalive; @@ -322,6 +323,8 @@ pub struct PeerOptions { pub tcp_fast_open: bool, // use Arc because Clone is required but not allowed in trait object pub tracer: Option, + // A custom L4 connector to use to establish new L4 connections + pub custom_l4: Option>, } impl PeerOptions { @@ -350,6 +353,7 @@ impl PeerOptions { second_keyshare: true, // default true and noop when not using PQ curves tcp_fast_open: false, tracer: None, + custom_l4: None, } } diff --git a/pingora-load-balancing/Cargo.toml b/pingora-load-balancing/Cargo.toml index 904e9f6..523e2f4 100644 --- a/pingora-load-balancing/Cargo.toml +++ b/pingora-load-balancing/Cargo.toml @@ -29,6 +29,8 @@ rand = "0" tokio = { workspace = true } futures = "0" log = { workspace = true } +http = { workspace = true } +derivative = "2.2.0" [dev-dependencies] diff --git a/pingora-load-balancing/src/discovery.rs b/pingora-load-balancing/src/discovery.rs index 5a38c2f..0c1ebdd 100644 --- a/pingora-load-balancing/src/discovery.rs +++ b/pingora-load-balancing/src/discovery.rs @@ -16,6 +16,7 @@ use arc_swap::ArcSwap; use async_trait::async_trait; +use http::Extensions; use pingora_core::protocols::l4::socket::SocketAddr; use pingora_error::Result; use std::io::Result as IoResult; @@ -62,6 +63,7 @@ impl Static { let addrs = addrs.to_socket_addrs()?.map(|addr| Backend { addr: SocketAddr::Inet(addr), weight: 1, + ext: Extensions::new(), }); upstreams.extend(addrs); } diff --git a/pingora-load-balancing/src/health_check.rs b/pingora-load-balancing/src/health_check.rs index dfd3ba0..5184e9e 100644 --- a/pingora-load-balancing/src/health_check.rs +++ b/pingora-load-balancing/src/health_check.rs @@ -154,9 +154,9 @@ pub struct HttpHealthCheck { /// Whether the underlying TCP/TLS connection can be reused across checks. /// /// * `false` will make sure that every health check goes through TCP (and TLS) handshakes. - /// Established connections sometimes hide the issue of firewalls and L4 LB. + /// Established connections sometimes hide the issue of firewalls and L4 LB. /// * `true` will try to reuse connections across checks, this is the more efficient and fast way - /// to perform health checks. + /// to perform health checks. pub reuse_connection: bool, /// The request header to send to the backend pub req: RequestHeader, @@ -350,6 +350,7 @@ mod test { use super::*; use crate::{discovery, Backends, SocketAddr}; use async_trait::async_trait; + use http::Extensions; #[tokio::test] async fn test_tcp_check() { @@ -358,6 +359,7 @@ mod test { let backend = Backend { addr: SocketAddr::Inet("1.1.1.1:80".parse().unwrap()), weight: 1, + ext: Extensions::new(), }; assert!(tcp_check.check(&backend).await.is_ok()); @@ -365,6 +367,7 @@ mod test { let backend = Backend { addr: SocketAddr::Inet("1.1.1.1:79".parse().unwrap()), weight: 1, + ext: Extensions::new(), }; assert!(tcp_check.check(&backend).await.is_err()); @@ -376,6 +379,7 @@ mod test { let backend = Backend { addr: SocketAddr::Inet("1.1.1.1:443".parse().unwrap()), weight: 1, + ext: Extensions::new(), }; assert!(tls_check.check(&backend).await.is_ok()); @@ -388,6 +392,7 @@ mod test { let backend = Backend { addr: SocketAddr::Inet("1.1.1.1:443".parse().unwrap()), weight: 1, + ext: Extensions::new(), }; assert!(https_check.check(&backend).await.is_ok()); @@ -410,6 +415,7 @@ mod test { let backend = Backend { addr: SocketAddr::Inet("1.1.1.1:80".parse().unwrap()), weight: 1, + ext: Extensions::new(), }; http_check.check(&backend).await.unwrap(); diff --git a/pingora-load-balancing/src/lib.rs b/pingora-load-balancing/src/lib.rs index 2777f9f..f14d2ee 100644 --- a/pingora-load-balancing/src/lib.rs +++ b/pingora-load-balancing/src/lib.rs @@ -16,8 +16,14 @@ //! This crate provides common service discovery, health check and load balancing //! algorithms for proxies to use. +// https://github.com/mcarton/rust-derivative/issues/112 +// False positive for macro generated code +#![allow(clippy::non_canonical_partial_ord_impl)] + use arc_swap::ArcSwap; +use derivative::Derivative; use futures::FutureExt; +pub use http::Extensions; use pingora_core::protocols::l4::socket::SocketAddr; use pingora_error::{ErrorType, OrErr, Result}; use std::collections::hash_map::DefaultHasher; @@ -45,13 +51,26 @@ pub mod prelude { } /// [Backend] represents a server to proxy or connect to. -#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)] +#[derive(Derivative)] +#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)] pub struct Backend { /// The address to the backend server. pub addr: SocketAddr, /// The relative weight of the server. Load balancing algorithms will /// proportionally distributed traffic according to this value. pub weight: usize, + + /// The extension field to put arbitrary data to annotate the Backend. + /// The data added here is opaque to this crate hence the data is ignored by + /// functionalities of this crate. For example, two backends with the same + /// [SocketAddr] and the same weight but different `ext` data are considered + /// identical. + /// See [Extensions] for how to add and read the data. + #[derivative(PartialEq = "ignore")] + #[derivative(PartialOrd = "ignore")] + #[derivative(Hash = "ignore")] + #[derivative(Ord = "ignore")] + pub ext: Extensions, } impl Backend { @@ -64,6 +83,7 @@ impl Backend { Ok(Backend { addr: SocketAddr::Inet(addr), weight: 1, + ext: Extensions::new(), }) // TODO: UDS } @@ -130,8 +150,17 @@ impl Backends { self.health_check = Some(hc.into()) } - /// Return true when the new is different from the current set of backends - fn do_update(&self, new_backends: BTreeSet, enablement: HashMap) -> bool { + /// Updates backends when the new is different from the current set, + /// the callback will be invoked when the new set of backend is different + /// from the current one so that the caller can update the selector accordingly. + fn do_update( + &self, + new_backends: BTreeSet, + enablement: HashMap, + callback: F, + ) where + F: Fn(Arc>), + { if (**self.backends.load()) != new_backends { let old_health = self.health.load(); let mut health = HashMap::with_capacity(new_backends.len()); @@ -147,10 +176,14 @@ impl Backends { health.insert(hash_key, backend_health); } - // TODO: put backend and health under 1 ArcSwap so that this update is atomic - self.backends.store(Arc::new(new_backends)); + // TODO: put this all under 1 ArcSwap so the update is atomic + // It's important the `callback()` executes first since computing selector backends might + // be expensive. For example, if a caller checks `backends` to see if any are available + // they may encounter false positives if the selector isn't ready yet. + let new_backends = Arc::new(new_backends); + callback(new_backends.clone()); + self.backends.store(new_backends); self.health.store(Arc::new(health)); - true } else { // no backend change, just check enablement for (hash_key, backend_enabled) in enablement.iter() { @@ -160,7 +193,6 @@ impl Backends { backend_health.enable(*backend_enabled); } } - false } } @@ -199,12 +231,15 @@ impl Backends { /// Call the service discovery method to update the collection of backends. /// - /// Return `true` when the new collection is different from the current set of backends. - /// This return value is useful to tell the caller when to rebuild things that are expensive to - /// update, such as consistent hashing rings. - pub async fn update(&self) -> Result { + /// The callback will be invoked when the new set of backend is different + /// from the current one so that the caller can update the selector accordingly. + pub async fn update(&self, callback: F) -> Result<()> + where + F: Fn(Arc>), + { let (new_backends, enablement) = self.discovery.discover().await?; - Ok(self.do_update(new_backends, enablement)) + self.do_update(new_backends, enablement, callback); + Ok(()) } /// Run health check on all backends if it is set. @@ -321,11 +356,9 @@ where /// This function will be called every `update_frequency` if this [LoadBalancer] instance /// is running as a background service. pub async fn update(&self) -> Result<()> { - if self.backends.update().await? { - self.selector - .store(Arc::new(S::build(&self.backends.get_backend()))) - } - Ok(()) + self.backends + .update(|backends| self.selector.store(Arc::new(S::build(&backends)))) + .await } /// Return the first healthy [Backend] according to the selection algorithm and the @@ -379,6 +412,8 @@ where #[cfg(test)] mod test { + use std::sync::atomic::{AtomicBool, Ordering::Relaxed}; + use super::*; use async_trait::async_trait; @@ -409,10 +444,20 @@ mod test { backends.set_health_check(check); // true: new backend discovered - assert!(backends.update().await.unwrap()); + let updated = AtomicBool::new(false); + backends + .update(|_| updated.store(true, Relaxed)) + .await + .unwrap(); + assert!(updated.load(Relaxed)); // false: no new backend discovered - assert!(!backends.update().await.unwrap()); + let updated = AtomicBool::new(false); + backends + .update(|_| updated.store(true, Relaxed)) + .await + .unwrap(); + assert!(!updated.load(Relaxed)); backends.run_health_check(false).await; @@ -425,6 +470,31 @@ mod test { assert!(backends.ready(&good2)); assert!(!backends.ready(&bad)); } + #[tokio::test] + async fn test_backends_with_ext() { + let discovery = discovery::Static::default(); + let mut b1 = Backend::new("1.1.1.1:80").unwrap(); + b1.ext.insert(true); + let mut b2 = Backend::new("1.0.0.1:80").unwrap(); + b2.ext.insert(1u8); + discovery.add(b1.clone()); + discovery.add(b2.clone()); + + let backends = Backends::new(Box::new(discovery)); + + // fill in the backends + backends.update(|_| {}).await.unwrap(); + + let backend = backends.get_backend(); + assert!(backend.contains(&b1)); + assert!(backend.contains(&b2)); + + let b2 = backend.first().unwrap(); + assert_eq!(b2.ext.get::(), Some(&1)); + + let b1 = backend.last().unwrap(); + assert_eq!(b1.ext.get::(), Some(&true)); + } #[tokio::test] async fn test_discovery_readiness() { @@ -450,7 +520,14 @@ mod test { let discovery = TestDiscovery(discovery); let backends = Backends::new(Box::new(discovery)); - assert!(backends.update().await.unwrap()); + + // true: new backend discovered + let updated = AtomicBool::new(false); + backends + .update(|_| updated.store(true, Relaxed)) + .await + .unwrap(); + assert!(updated.load(Relaxed)); let backend = backends.get_backend(); assert!(backend.contains(&good1)); @@ -477,7 +554,12 @@ mod test { backends.set_health_check(check); // true: new backend discovered - assert!(backends.update().await.unwrap()); + let updated = AtomicBool::new(false); + backends + .update(|_| updated.store(true, Relaxed)) + .await + .unwrap(); + assert!(updated.load(Relaxed)); backends.run_health_check(true).await; @@ -485,4 +567,46 @@ mod test { assert!(backends.ready(&good2)); assert!(!backends.ready(&bad)); } + + mod thread_safety { + use super::*; + + struct MockDiscovery { + expected: usize, + } + #[async_trait] + impl ServiceDiscovery for MockDiscovery { + async fn discover(&self) -> Result<(BTreeSet, HashMap)> { + let mut d = BTreeSet::new(); + let mut m = HashMap::with_capacity(self.expected); + for i in 0..self.expected { + let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap(); + m.insert(i as u64, true); + d.insert(b); + } + Ok((d, m)) + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_consistency() { + let expected = 3000; + let discovery = MockDiscovery { expected }; + let lb = Arc::new(LoadBalancer::::from_backends( + Backends::new(Box::new(discovery)), + )); + let lb2 = lb.clone(); + + tokio::spawn(async move { + assert!(lb2.update().await.is_ok()); + }); + let mut backend_count = 0; + while backend_count == 0 { + let backends = lb.backends(); + backend_count = backends.backends.load_full().len(); + } + assert_eq!(backend_count, expected); + assert!(lb.select_with(b"test", 1, |_, _| true).is_some()); + } + } } diff --git a/pingora-openssl/src/ext.rs b/pingora-openssl/src/ext.rs index 14d248b..6e0ff60 100644 --- a/pingora-openssl/src/ext.rs +++ b/pingora-openssl/src/ext.rs @@ -27,7 +27,7 @@ use openssl_sys::{ use std::ffi::CString; use std::os::raw; -fn cvt(r: c_int) -> Result { +fn cvt(r: c_long) -> Result { if r != 1 { Err(ErrorStack::get()) } else { @@ -42,8 +42,8 @@ extern "C" { namelen: size_t, ) -> c_int; - pub fn SSL_use_certificate(ssl: *const SSL, cert: *mut X509) -> c_int; - pub fn SSL_use_PrivateKey(ctx: *const SSL, key: *mut EVP_PKEY) -> c_int; + pub fn SSL_use_certificate(ssl: *mut SSL, cert: *mut X509) -> c_int; + pub fn SSL_use_PrivateKey(ssl: *mut SSL, key: *mut EVP_PKEY) -> c_int; pub fn SSL_set_cert_cb( ssl: *mut SSL, @@ -64,9 +64,9 @@ pub fn add_host(verify_param: &mut X509VerifyParamRef, host: &str) -> Result<(), unsafe { cvt(X509_VERIFY_PARAM_add1_host( verify_param.as_ptr(), - host.as_ptr() as *const _, + host.as_ptr() as *const c_char, host.len(), - )) + ) as c_long) .map(|_| ()) } } @@ -84,7 +84,7 @@ pub fn ssl_set_verify_cert_store( SSL_CTRL_SET_VERIFY_CERT_STORE, 1, // increase the ref count of X509Store so that ssl_ctx can outlive X509StoreRef cert_store.as_ptr() as *mut c_void, - ) as i32)?; + ))?; } Ok(()) } @@ -94,7 +94,7 @@ pub fn ssl_set_verify_cert_store( /// See [SSL_use_certificate](https://www.openssl.org/docs/man1.1.1/man3/SSL_use_certificate.html). pub fn ssl_use_certificate(ssl: &mut SslRef, cert: &X509Ref) -> Result<(), ErrorStack> { unsafe { - cvt(SSL_use_certificate(ssl.as_ptr(), cert.as_ptr()))?; + cvt(SSL_use_certificate(ssl.as_ptr() as *mut SSL, cert.as_ptr() as *mut X509) as c_long)?; } Ok(()) } @@ -107,7 +107,7 @@ where T: HasPrivate, { unsafe { - cvt(SSL_use_PrivateKey(ssl.as_ptr(), key.as_ptr()))?; + cvt(SSL_use_PrivateKey(ssl.as_ptr() as *mut SSL, key.as_ptr() as *mut EVP_PKEY) as c_long)?; } Ok(()) } @@ -123,7 +123,7 @@ pub fn ssl_add_chain_cert(ssl: &mut SslRef, cert: &X509Ref) -> Result<(), ErrorS SSL_CTRL_CHAIN_CERT, 1, // increase the ref count of X509 so that ssl can outlive X509StoreRef cert.as_ptr() as *mut c_void, - ) as i32)?; + ))?; } Ok(()) } @@ -137,14 +137,17 @@ pub fn ssl_set_renegotiate_mode_freely(_ssl: &mut SslRef) {} /// /// See [set_groups_list](https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set1_curves.html). pub fn ssl_set_groups_list(ssl: &mut SslRef, groups: &str) -> Result<(), ErrorStack> { - let groups = CString::new(groups).unwrap(); + if groups.contains('\0') { + return Err(ErrorStack::get()); + } + let groups = CString::new(groups).map_err(|_| ErrorStack::get())?; unsafe { cvt(SSL_ctrl( ssl.as_ptr(), SSL_CTRL_SET_GROUPS_LIST, 0, groups.as_ptr() as *mut c_void, - ) as i32)?; + ))?; } Ok(()) } @@ -207,3 +210,22 @@ pub fn is_suspended_for_cert(error: &openssl::ssl::Error) -> bool { pub unsafe fn ssl_mut(ssl: &SslRef) -> &mut SslRef { SslRef::from_ptr_mut(ssl.as_ptr()) } + +#[cfg(test)] +mod tests { + use super::*; + use openssl::ssl::{SslContextBuilder, SslMethod}; + + #[test] + fn test_ssl_set_groups_list() { + let ctx_builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); + let ssl = Ssl::new(&ctx_builder.build()).unwrap(); + let ssl_ref = unsafe { ssl_mut(&ssl) }; + + // Valid input + assert!(ssl_set_groups_list(ssl_ref, "P-256:P-384").is_ok()); + + // Invalid input (contains null byte) + assert!(ssl_set_groups_list(ssl_ref, "P-256\0P-384").is_err()); + } +} diff --git a/pingora-proxy/Cargo.toml b/pingora-proxy/Cargo.toml index 65f7f67..0a52dff 100644 --- a/pingora-proxy/Cargo.toml +++ b/pingora-proxy/Cargo.toml @@ -55,3 +55,10 @@ serde_yaml = "0.8" default = ["openssl"] openssl = ["pingora-core/openssl", "pingora-cache/openssl"] boringssl = ["pingora-core/boringssl", "pingora-cache/boringssl"] + +# or locally cargo doc --config "build.rustdocflags='--cfg doc_async_trait'" +[package.metadata.docs.rs] +rustdoc-args = ["--cfg", "doc_async_trait"] + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(doc_async_trait)'] } \ No newline at end of file diff --git a/pingora-proxy/src/lib.rs b/pingora-proxy/src/lib.rs index a02b23f..eebf65d 100644 --- a/pingora-proxy/src/lib.rs +++ b/pingora-proxy/src/lib.rs @@ -35,9 +35,6 @@ //! //! See `examples/load_balancer.rs` for a detailed example. -// enable nightly feature async trait so that the docs are cleaner -#![cfg_attr(doc_async_trait, feature(async_fn_in_trait))] - use async_trait::async_trait; use bytes::Bytes; use futures::future::FutureExt; diff --git a/pingora-timeout/Cargo.toml b/pingora-timeout/Cargo.toml index 92e72e7..1bca525 100644 --- a/pingora-timeout/Cargo.toml +++ b/pingora-timeout/Cargo.toml @@ -24,7 +24,6 @@ tokio = { workspace = true, features = [ "sync", ] } pin-project-lite = "0.2" -futures = "0.3" once_cell = { workspace = true } parking_lot = "0.12" thread_local = "1.0" diff --git a/pingora-timeout/src/fast_timeout.rs b/pingora-timeout/src/fast_timeout.rs index c3b251a..d5de5f4 100644 --- a/pingora-timeout/src/fast_timeout.rs +++ b/pingora-timeout/src/fast_timeout.rs @@ -50,7 +50,7 @@ fn check_clock_thread(tm: &Arc) { pub struct FastTimeout(Duration); impl ToTimeout for FastTimeout { - fn timeout(&self) -> BoxFuture<'static, ()> { + fn timeout(&self) -> Pin + Send + Sync>> { Box::pin(TIMER_MANAGER.register_timer(self.0).poll()) } diff --git a/pingora-timeout/src/lib.rs b/pingora-timeout/src/lib.rs index 75b0663..26bb5b9 100644 --- a/pingora-timeout/src/lib.rs +++ b/pingora-timeout/src/lib.rs @@ -39,7 +39,6 @@ pub mod timer; pub use fast_timeout::fast_sleep as sleep; pub use fast_timeout::fast_timeout as timeout; -use futures::future::BoxFuture; use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; @@ -50,7 +49,7 @@ use tokio::time::{sleep as tokio_sleep, Duration}; /// /// Users don't need to interact with this trait pub trait ToTimeout { - fn timeout(&self) -> BoxFuture<'static, ()>; + fn timeout(&self) -> Pin + Send + Sync>>; fn create(d: Duration) -> Self; } @@ -60,7 +59,7 @@ pub trait ToTimeout { pub struct TokioTimeout(Duration); impl ToTimeout for TokioTimeout { - fn timeout(&self) -> BoxFuture<'static, ()> { + fn timeout(&self) -> Pin + Send + Sync>> { Box::pin(tokio_sleep(self.0)) } @@ -100,7 +99,7 @@ pin_project! { #[pin] value: T, #[pin] - delay: Option>, + delay: Option + Send + Sync>>>, callback: F, // callback to create the timer } } diff --git a/pingora/Cargo.toml b/pingora/Cargo.toml index 75baf1b..7f5e101 100644 --- a/pingora/Cargo.toml +++ b/pingora/Cargo.toml @@ -63,3 +63,4 @@ boringssl = [ proxy = ["pingora-proxy"] lb = ["pingora-load-balancing", "proxy"] cache = ["pingora-cache"] +time = []