mirror of
https://github.com/cloudflare/pingora.git
synced 2024-09-20 02:31:35 +02:00
Merge branch 'main' into patch-1
This commit is contained in:
commit
1ee9b327b4
29 changed files with 468 additions and 147 deletions
2
.bleep
2
.bleep
|
@ -1 +1 @@
|
|||
837db6c7ec2d37abf83f9588be99fda00e2012c3
|
||||
c90e4ce2596840c60b5ff1737e2141447e5953e1
|
||||
|
|
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
|
@ -46,7 +46,7 @@ jobs:
|
|||
|
||||
- name: Run cargo clippy
|
||||
run: |
|
||||
[[ ${{ matrix.toolchain }} == nightly ]] || cargo clippy --all-targets --all -- --deny=warnings
|
||||
[[ ${{ matrix.toolchain }} == nightly ]] || cargo clippy --all-targets --all -- --allow=unknown-lints --deny=warnings
|
||||
|
||||
- name: Run cargo audit
|
||||
run: |
|
||||
|
|
|
@ -32,6 +32,7 @@ use std::time::SystemTime;
|
|||
///
|
||||
/// - Space optimized in-memory LRU (see [pingora_lru]).
|
||||
/// - Instead of a single giant LRU, this struct shards the assets into `N` independent LRUs.
|
||||
///
|
||||
/// This allows [EvictionManager::save()] not to lock the entire cache manager while performing
|
||||
/// serialization.
|
||||
pub struct Manager<const N: usize>(Lru<CompactCacheKey, N>);
|
||||
|
|
|
@ -28,7 +28,7 @@ use crate::protocols::Stream;
|
|||
use crate::server::ShutdownWatch;
|
||||
|
||||
/// This trait defines how to map a request to a response
|
||||
#[cfg_attr(not(doc_async_trait), async_trait)]
|
||||
#[async_trait]
|
||||
pub trait ServeHttp {
|
||||
/// Define the mapping from a request to a response.
|
||||
/// Note that the request header is already read, but the implementation needs to read the
|
||||
|
@ -42,7 +42,7 @@ pub trait ServeHttp {
|
|||
}
|
||||
|
||||
// TODO: remove this in favor of HttpServer?
|
||||
#[cfg_attr(not(doc_async_trait), async_trait)]
|
||||
#[async_trait]
|
||||
impl<SV> HttpServerApp for SV
|
||||
where
|
||||
SV: ServeHttp + Send + Sync,
|
||||
|
@ -128,7 +128,7 @@ impl<SV> HttpServer<SV> {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(not(doc_async_trait), async_trait)]
|
||||
#[async_trait]
|
||||
impl<SV> HttpServerApp for HttpServer<SV>
|
||||
where
|
||||
SV: ServeHttp + Send + Sync,
|
||||
|
|
|
@ -28,7 +28,7 @@ use crate::protocols::Digest;
|
|||
use crate::protocols::Stream;
|
||||
use crate::protocols::ALPN;
|
||||
|
||||
#[cfg_attr(not(doc_async_trait), async_trait)]
|
||||
#[async_trait]
|
||||
/// This trait defines the interface of a transport layer (TCP or TLS) application.
|
||||
pub trait ServerApp {
|
||||
/// Whenever a new connection is established, this function will be called with the established
|
||||
|
@ -62,7 +62,7 @@ pub struct HttpServerOptions {
|
|||
}
|
||||
|
||||
/// This trait defines the interface of an HTTP application.
|
||||
#[cfg_attr(not(doc_async_trait), async_trait)]
|
||||
#[async_trait]
|
||||
pub trait HttpServerApp {
|
||||
/// Similar to the [`ServerApp`], this function is called whenever a new HTTP session is established.
|
||||
///
|
||||
|
@ -95,7 +95,7 @@ pub trait HttpServerApp {
|
|||
async fn http_cleanup(&self) {}
|
||||
}
|
||||
|
||||
#[cfg_attr(not(doc_async_trait), async_trait)]
|
||||
#[async_trait]
|
||||
impl<T> ServerApp for T
|
||||
where
|
||||
T: HttpServerApp + Send + Sync + 'static,
|
||||
|
|
|
@ -29,7 +29,7 @@ use crate::protocols::http::ServerSession;
|
|||
/// collected via the [Prometheus](https://docs.rs/prometheus/) crate;
|
||||
pub struct PrometheusHttpApp;
|
||||
|
||||
#[cfg_attr(not(doc_async_trait), async_trait)]
|
||||
#[async_trait]
|
||||
impl ServeHttp for PrometheusHttpApp {
|
||||
async fn response(&self, _http_session: &mut ServerSession) -> Response<Vec<u8>> {
|
||||
let encoder = TextEncoder::new();
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use log::debug;
|
||||
use pingora_error::{Context, Error, ErrorType::*, OrErr, Result};
|
||||
use rand::seq::SliceRandom;
|
||||
|
@ -26,6 +27,12 @@ use crate::protocols::l4::stream::Stream;
|
|||
use crate::protocols::{GetSocketDigest, SocketDigest};
|
||||
use crate::upstreams::peer::Peer;
|
||||
|
||||
/// The interface to establish a L4 connection
|
||||
#[async_trait]
|
||||
pub trait Connect: std::fmt::Debug {
|
||||
async fn connect(&self, addr: &SocketAddr) -> Result<Stream>;
|
||||
}
|
||||
|
||||
/// Establish a connection (l4) to the given peer using its settings and an optional bind address.
|
||||
pub async fn connect<P>(peer: &P, bind_to: Option<InetSocketAddr>) -> Result<Stream>
|
||||
where
|
||||
|
@ -37,7 +44,11 @@ where
|
|||
.err_context(|| format!("Fail to establish CONNECT proxy: {}", peer));
|
||||
}
|
||||
let peer_addr = peer.address();
|
||||
let mut stream: Stream = match peer_addr {
|
||||
let mut stream: Stream =
|
||||
if let Some(custom_l4) = peer.get_peer_options().and_then(|o| o.custom_l4.as_ref()) {
|
||||
custom_l4.connect(peer_addr).await?
|
||||
} else {
|
||||
match peer_addr {
|
||||
SocketAddr::Inet(addr) => {
|
||||
let connect_future = tcp_connect(addr, bind_to.as_ref(), |socket| {
|
||||
if peer.tcp_fast_open() {
|
||||
|
@ -102,7 +113,9 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
}?;
|
||||
}?
|
||||
};
|
||||
|
||||
let tracer = peer.get_tracer();
|
||||
if let Some(t) = tracer {
|
||||
t.0.on_connected();
|
||||
|
@ -249,6 +262,29 @@ mod tests {
|
|||
assert_eq!(new_session.unwrap_err().etype(), &ConnectTimedout)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_custom_connect() {
|
||||
#[derive(Debug)]
|
||||
struct MyL4;
|
||||
#[async_trait]
|
||||
impl Connect for MyL4 {
|
||||
async fn connect(&self, _addr: &SocketAddr) -> Result<Stream> {
|
||||
tokio::net::TcpStream::connect("1.1.1.1:80")
|
||||
.await
|
||||
.map(|s| s.into())
|
||||
.or_fail()
|
||||
}
|
||||
}
|
||||
// :79 shouldn't be able to be connected to
|
||||
let mut peer = BasicPeer::new("1.1.1.1:79");
|
||||
peer.options.custom_l4 = Some(std::sync::Arc::new(MyL4 {}));
|
||||
|
||||
let new_session = connect(&peer, None).await;
|
||||
|
||||
// but MyL4 connects to :80 instead
|
||||
assert!(new_session.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connect_proxy_fail() {
|
||||
let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
|
||||
|
|
|
@ -25,6 +25,7 @@ use crate::tls::ssl::SslConnector;
|
|||
use crate::upstreams::peer::{Peer, ALPN};
|
||||
|
||||
use l4::connect as l4_connect;
|
||||
pub use l4::Connect as L4Connect;
|
||||
use log::{debug, error, warn};
|
||||
use offload::OffloadRuntime;
|
||||
use parking_lot::RwLock;
|
||||
|
|
|
@ -18,8 +18,6 @@
|
|||
#![allow(clippy::match_wild_err_arm)]
|
||||
#![allow(clippy::missing_safety_doc)]
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
// enable nightly feature async trait so that the docs are cleaner
|
||||
#![cfg_attr(doc_async_trait, feature(async_fn_in_trait))]
|
||||
|
||||
//! # Pingora
|
||||
//!
|
||||
|
|
|
@ -218,6 +218,19 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sets whether we ignore writing informational responses downstream.
|
||||
///
|
||||
/// For HTTP/1.1 this is a noop if the response is Upgrade or Continue and
|
||||
/// Expect: 100-continue was set on the request.
|
||||
///
|
||||
/// This is a noop for h2 because informational responses are always ignored.
|
||||
pub fn set_ignore_info_resp(&mut self, ignore: bool) {
|
||||
match self {
|
||||
Self::H1(s) => s.set_ignore_info_resp(ignore),
|
||||
Self::H2(_) => {} // always ignored
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a digest of the request including the method, path and Host header
|
||||
// TODO: make this use a `Formatter`
|
||||
pub fn request_summary(&self) -> String {
|
||||
|
|
|
@ -153,6 +153,14 @@ pub(super) fn is_upgrade_req(req: &RequestHeader) -> bool {
|
|||
req.version == http::Version::HTTP_11 && req.headers.get(header::UPGRADE).is_some()
|
||||
}
|
||||
|
||||
pub(super) fn is_expect_continue_req(req: &RequestHeader) -> bool {
|
||||
req.version == http::Version::HTTP_11
|
||||
// https://www.rfc-editor.org/rfc/rfc9110#section-10.1.1
|
||||
&& req.headers.get(header::EXPECT).map_or(false, |v| {
|
||||
v.as_bytes().eq_ignore_ascii_case(b"100-continue")
|
||||
})
|
||||
}
|
||||
|
||||
// Unlike the upgrade check on request, this function doesn't check the Upgrade or Connection header
|
||||
// because when seeing 101, we assume the server accepts to switch protocol.
|
||||
// In reality it is not common that some servers don't send all the required headers to establish
|
||||
|
|
|
@ -74,6 +74,8 @@ pub struct HttpSession {
|
|||
digest: Box<Digest>,
|
||||
/// Minimum send rate to the client
|
||||
min_send_rate: Option<usize>,
|
||||
/// When this is enabled informational response headers will not be proxied downstream
|
||||
ignore_info_resp: bool,
|
||||
}
|
||||
|
||||
impl HttpSession {
|
||||
|
@ -109,6 +111,7 @@ impl HttpSession {
|
|||
upgraded: false,
|
||||
digest,
|
||||
min_send_rate: None,
|
||||
ignore_info_resp: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -388,6 +391,11 @@ impl HttpSession {
|
|||
/// Write the response header to the client.
|
||||
/// This function can be called more than once to send 1xx informational headers excluding 101.
|
||||
pub async fn write_response_header(&mut self, mut header: Box<ResponseHeader>) -> Result<()> {
|
||||
if header.status.is_informational() && self.ignore_info_resp(header.status.into()) {
|
||||
debug!("ignoring informational headers");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(resp) = self.response_written.as_ref() {
|
||||
if !resp.status.is_informational() || self.upgraded {
|
||||
warn!("Respond header is already sent, cannot send again");
|
||||
|
@ -409,7 +417,7 @@ impl HttpSession {
|
|||
header.insert_header(header::CONNECTION, connection_value)?;
|
||||
}
|
||||
|
||||
if header.status.as_u16() == 101 {
|
||||
if header.status == 101 {
|
||||
// make sure the connection is closed at the end when 101/upgrade is used
|
||||
self.set_keepalive(None);
|
||||
}
|
||||
|
@ -510,6 +518,18 @@ impl HttpSession {
|
|||
(None, None)
|
||||
}
|
||||
|
||||
fn ignore_info_resp(&self, status: u16) -> bool {
|
||||
// ignore informational response if ignore flag is set and it's not an Upgrade and Expect: 100-continue isn't set
|
||||
self.ignore_info_resp && status != 101 && !(status == 100 && self.is_expect_continue_req())
|
||||
}
|
||||
|
||||
fn is_expect_continue_req(&self) -> bool {
|
||||
match self.request_header.as_deref() {
|
||||
Some(req) => is_expect_continue_req(req),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_connection_keepalive(&self) -> Option<bool> {
|
||||
is_buf_keepalive(self.get_header(header::CONNECTION))
|
||||
}
|
||||
|
@ -824,6 +844,14 @@ impl HttpSession {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sets whether we ignore writing informational responses downstream.
|
||||
///
|
||||
/// This is a noop if the response is Upgrade or Continue and
|
||||
/// Expect: 100-continue was set on the request.
|
||||
pub fn set_ignore_info_resp(&mut self, ignore: bool) {
|
||||
self.ignore_info_resp = ignore;
|
||||
}
|
||||
|
||||
/// Return the [Digest] of the connection.
|
||||
pub fn digest(&self) -> &Digest {
|
||||
&self.digest
|
||||
|
@ -1472,6 +1500,75 @@ mod tests_stream {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_informational_ignored() {
|
||||
let wire = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
|
||||
let mock_io = Builder::new().write(wire).build();
|
||||
let mut http_stream = HttpSession::new(Box::new(mock_io));
|
||||
// ignore the 100 Continue
|
||||
http_stream.ignore_info_resp = true;
|
||||
let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap();
|
||||
http_stream
|
||||
.write_response_header_ref(&response_100)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap();
|
||||
response_200.append_header("Foo", "Bar").unwrap();
|
||||
http_stream.update_resp_headers = false;
|
||||
http_stream
|
||||
.write_response_header_ref(&response_200)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_informational_100_not_ignored_if_expect_continue() {
|
||||
let input = b"GET / HTTP/1.1\r\nExpect: 100-continue\r\n\r\n";
|
||||
let output = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
|
||||
|
||||
let mock_io = Builder::new().read(&input[..]).write(output).build();
|
||||
let mut http_stream = HttpSession::new(Box::new(mock_io));
|
||||
http_stream.read_request().await.unwrap();
|
||||
http_stream.ignore_info_resp = true;
|
||||
// 100 Continue is not ignored due to Expect: 100-continue on request
|
||||
let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap();
|
||||
http_stream
|
||||
.write_response_header_ref(&response_100)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap();
|
||||
response_200.append_header("Foo", "Bar").unwrap();
|
||||
http_stream.update_resp_headers = false;
|
||||
http_stream
|
||||
.write_response_header_ref(&response_200)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_informational_1xx_ignored_if_expect_continue() {
|
||||
let input = b"GET / HTTP/1.1\r\nExpect: 100-continue\r\n\r\n";
|
||||
let output = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
|
||||
|
||||
let mock_io = Builder::new().read(&input[..]).write(output).build();
|
||||
let mut http_stream = HttpSession::new(Box::new(mock_io));
|
||||
http_stream.read_request().await.unwrap();
|
||||
http_stream.ignore_info_resp = true;
|
||||
// 102 Processing is ignored
|
||||
let response_102 = ResponseHeader::build(StatusCode::PROCESSING, None).unwrap();
|
||||
http_stream
|
||||
.write_response_header_ref(&response_102)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap();
|
||||
response_200.append_header("Foo", "Bar").unwrap();
|
||||
http_stream.update_resp_headers = false;
|
||||
http_stream
|
||||
.write_response_header_ref(&response_200)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_101_switching_protocol() {
|
||||
let wire = b"HTTP/1.1 101 Switching Protocols\r\nFoo: Bar\r\n\r\n";
|
||||
|
|
|
@ -194,21 +194,20 @@ impl HttpSession {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
// FIXME: we should ignore 1xx header because send_response() can only be called once
|
||||
if header.status.is_informational() {
|
||||
// ignore informational response 1xx header because send_response() can only be called once
|
||||
// https://github.com/hyperium/h2/issues/167
|
||||
|
||||
if let Some(resp) = self.response_written.as_ref() {
|
||||
if !resp.status.is_informational() {
|
||||
warn!("Respond header is already sent, cannot send again");
|
||||
debug!("ignoring informational headers");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if self.response_written.as_ref().is_some() {
|
||||
warn!("Response header is already sent, cannot send again");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// no need to add these headers to 1xx responses
|
||||
if !header.status.is_informational() {
|
||||
/* update headers */
|
||||
header.insert_header(header::DATE, get_cached_date())?;
|
||||
}
|
||||
|
||||
// remove other h1 hop headers that cannot be present in H2
|
||||
// https://httpwg.org/specs/rfc7540.html#n-connection-specific-header-fields
|
||||
|
@ -486,7 +485,8 @@ mod test {
|
|||
expected_trailers.insert("test", HeaderValue::from_static("trailers"));
|
||||
let trailers = expected_trailers.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut handles = vec![];
|
||||
handles.push(tokio::spawn(async move {
|
||||
let (h2, connection) = h2::client::handshake(client).await.unwrap();
|
||||
tokio::spawn(async move {
|
||||
connection.await.unwrap();
|
||||
|
@ -510,7 +510,7 @@ mod test {
|
|||
assert_eq!(data, server_body);
|
||||
let resp_trailers = body.trailers().await.unwrap().unwrap();
|
||||
assert_eq!(resp_trailers, expected_trailers);
|
||||
});
|
||||
}));
|
||||
|
||||
let mut connection = handshake(Box::new(server), None).await.unwrap();
|
||||
let digest = Arc::new(Digest::default());
|
||||
|
@ -520,7 +520,7 @@ mod test {
|
|||
.unwrap()
|
||||
{
|
||||
let trailers = trailers.clone();
|
||||
tokio::spawn(async move {
|
||||
handles.push(tokio::spawn(async move {
|
||||
let req = http.req_header();
|
||||
assert_eq!(req.method, Method::GET);
|
||||
assert_eq!(req.uri, "https://www.example.com/");
|
||||
|
@ -545,7 +545,11 @@ mod test {
|
|||
}
|
||||
|
||||
let response_header = Box::new(ResponseHeader::build(200, None).unwrap());
|
||||
http.write_response_header(response_header, false).unwrap();
|
||||
assert!(http
|
||||
.write_response_header(response_header.clone(), false)
|
||||
.is_ok());
|
||||
// this write should be ignored otherwise we will error
|
||||
assert!(http.write_response_header(response_header, false).is_ok());
|
||||
|
||||
// test idling after response header is sent
|
||||
tokio::select! {
|
||||
|
@ -559,7 +563,11 @@ mod test {
|
|||
|
||||
http.write_trailers(trailers).unwrap();
|
||||
http.finish().unwrap();
|
||||
});
|
||||
}));
|
||||
}
|
||||
for handle in handles {
|
||||
// ensure no panics
|
||||
assert!(handle.await.is_ok());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -369,14 +369,11 @@ fn wrap_os_connect_error(e: std::io::Error, context: String) -> Box<Error> {
|
|||
Error::because(InternalError, context, e)
|
||||
}
|
||||
_ => match e.raw_os_error() {
|
||||
Some(code) => match code {
|
||||
libc::ENETUNREACH | libc::EHOSTUNREACH => {
|
||||
Some(libc::ENETUNREACH | libc::EHOSTUNREACH) => {
|
||||
Error::because(ConnectNoRoute, context, e)
|
||||
}
|
||||
_ => Error::because(ConnectError, context, e),
|
||||
},
|
||||
None => Error::because(ConnectError, context, e),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ use super::Service;
|
|||
use crate::server::{ListenFds, ShutdownWatch};
|
||||
|
||||
/// The background service interface
|
||||
#[cfg_attr(not(doc_async_trait), async_trait)]
|
||||
#[async_trait]
|
||||
pub trait BackgroundService {
|
||||
/// This function is called when the pingora server tries to start all the
|
||||
/// services. The background service can return at anytime or wait for the
|
||||
|
|
|
@ -29,6 +29,7 @@ use std::path::{Path, PathBuf};
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::connectors::L4Connect;
|
||||
use crate::protocols::l4::socket::SocketAddr;
|
||||
use crate::protocols::ConnFdReusable;
|
||||
use crate::protocols::TcpKeepalive;
|
||||
|
@ -322,6 +323,8 @@ pub struct PeerOptions {
|
|||
pub tcp_fast_open: bool,
|
||||
// use Arc because Clone is required but not allowed in trait object
|
||||
pub tracer: Option<Tracer>,
|
||||
// A custom L4 connector to use to establish new L4 connections
|
||||
pub custom_l4: Option<Arc<dyn L4Connect + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl PeerOptions {
|
||||
|
@ -350,6 +353,7 @@ impl PeerOptions {
|
|||
second_keyshare: true, // default true and noop when not using PQ curves
|
||||
tcp_fast_open: false,
|
||||
tracer: None,
|
||||
custom_l4: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -29,6 +29,8 @@ rand = "0"
|
|||
tokio = { workspace = true }
|
||||
futures = "0"
|
||||
log = { workspace = true }
|
||||
http = { workspace = true }
|
||||
derivative = "2.2.0"
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use http::Extensions;
|
||||
use pingora_core::protocols::l4::socket::SocketAddr;
|
||||
use pingora_error::Result;
|
||||
use std::io::Result as IoResult;
|
||||
|
@ -62,6 +63,7 @@ impl Static {
|
|||
let addrs = addrs.to_socket_addrs()?.map(|addr| Backend {
|
||||
addr: SocketAddr::Inet(addr),
|
||||
weight: 1,
|
||||
ext: Extensions::new(),
|
||||
});
|
||||
upstreams.extend(addrs);
|
||||
}
|
||||
|
|
|
@ -315,6 +315,7 @@ impl Health {
|
|||
mod test {
|
||||
use super::*;
|
||||
use crate::SocketAddr;
|
||||
use http::Extensions;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_check() {
|
||||
|
@ -323,6 +324,7 @@ mod test {
|
|||
let backend = Backend {
|
||||
addr: SocketAddr::Inet("1.1.1.1:80".parse().unwrap()),
|
||||
weight: 1,
|
||||
ext: Extensions::new(),
|
||||
};
|
||||
|
||||
assert!(tcp_check.check(&backend).await.is_ok());
|
||||
|
@ -330,6 +332,7 @@ mod test {
|
|||
let backend = Backend {
|
||||
addr: SocketAddr::Inet("1.1.1.1:79".parse().unwrap()),
|
||||
weight: 1,
|
||||
ext: Extensions::new(),
|
||||
};
|
||||
|
||||
assert!(tcp_check.check(&backend).await.is_err());
|
||||
|
@ -341,6 +344,7 @@ mod test {
|
|||
let backend = Backend {
|
||||
addr: SocketAddr::Inet("1.1.1.1:443".parse().unwrap()),
|
||||
weight: 1,
|
||||
ext: Extensions::new(),
|
||||
};
|
||||
|
||||
assert!(tls_check.check(&backend).await.is_ok());
|
||||
|
@ -353,6 +357,7 @@ mod test {
|
|||
let backend = Backend {
|
||||
addr: SocketAddr::Inet("1.1.1.1:443".parse().unwrap()),
|
||||
weight: 1,
|
||||
ext: Extensions::new(),
|
||||
};
|
||||
|
||||
assert!(https_check.check(&backend).await.is_ok());
|
||||
|
@ -375,6 +380,7 @@ mod test {
|
|||
let backend = Backend {
|
||||
addr: SocketAddr::Inet("1.1.1.1:80".parse().unwrap()),
|
||||
weight: 1,
|
||||
ext: Extensions::new(),
|
||||
};
|
||||
|
||||
http_check.check(&backend).await.unwrap();
|
||||
|
|
|
@ -16,8 +16,14 @@
|
|||
//! This crate provides common service discovery, health check and load balancing
|
||||
//! algorithms for proxies to use.
|
||||
|
||||
// https://github.com/mcarton/rust-derivative/issues/112
|
||||
// False positive for macro generated code
|
||||
#![allow(clippy::non_canonical_partial_ord_impl)]
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use derivative::Derivative;
|
||||
use futures::FutureExt;
|
||||
pub use http::Extensions;
|
||||
use pingora_core::protocols::l4::socket::SocketAddr;
|
||||
use pingora_error::{ErrorType, OrErr, Result};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
|
@ -45,13 +51,26 @@ pub mod prelude {
|
|||
}
|
||||
|
||||
/// [Backend] represents a server to proxy or connect to.
|
||||
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
|
||||
#[derive(Derivative)]
|
||||
#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
|
||||
pub struct Backend {
|
||||
/// The address to the backend server.
|
||||
pub addr: SocketAddr,
|
||||
/// The relative weight of the server. Load balancing algorithms will
|
||||
/// proportionally distributed traffic according to this value.
|
||||
pub weight: usize,
|
||||
|
||||
/// The extension field to put arbitrary data to annotate the Backend.
|
||||
/// The data added here is opaque to this crate hence the data is ignored by
|
||||
/// functionalities of this crate. For example, two backends with the same
|
||||
/// [SocketAddr] and the same weight but different `ext` data are considered
|
||||
/// identical.
|
||||
/// See [Extensions] for how to add and read the data.
|
||||
#[derivative(PartialEq = "ignore")]
|
||||
#[derivative(PartialOrd = "ignore")]
|
||||
#[derivative(Hash = "ignore")]
|
||||
#[derivative(Ord = "ignore")]
|
||||
pub ext: Extensions,
|
||||
}
|
||||
|
||||
impl Backend {
|
||||
|
@ -137,8 +156,17 @@ impl Backends {
|
|||
self.health_check = Some(hc.into())
|
||||
}
|
||||
|
||||
/// Return true when the new is different from the current set of backends
|
||||
fn do_update(&self, new_backends: BTreeSet<Backend>, enablement: HashMap<u64, bool>) -> bool {
|
||||
/// Updates backends when the new is different from the current set,
|
||||
/// the callback will be invoked when the new set of backend is different
|
||||
/// from the current one so that the caller can update the selector accordingly.
|
||||
fn do_update<F>(
|
||||
&self,
|
||||
new_backends: BTreeSet<Backend>,
|
||||
enablement: HashMap<u64, bool>,
|
||||
callback: F,
|
||||
) where
|
||||
F: Fn(Arc<BTreeSet<Backend>>),
|
||||
{
|
||||
if (**self.backends.load()) != new_backends {
|
||||
let old_health = self.health.load();
|
||||
let mut health = HashMap::with_capacity(new_backends.len());
|
||||
|
@ -154,10 +182,14 @@ impl Backends {
|
|||
health.insert(hash_key, backend_health);
|
||||
}
|
||||
|
||||
// TODO: put backend and health under 1 ArcSwap so that this update is atomic
|
||||
self.backends.store(Arc::new(new_backends));
|
||||
// TODO: put this all under 1 ArcSwap so the update is atomic
|
||||
// It's important the `callback()` executes first since computing selector backends might
|
||||
// be expensive. For example, if a caller checks `backends` to see if any are available
|
||||
// they may encounter false positives if the selector isn't ready yet.
|
||||
let new_backends = Arc::new(new_backends);
|
||||
callback(new_backends.clone());
|
||||
self.backends.store(new_backends);
|
||||
self.health.store(Arc::new(health));
|
||||
true
|
||||
} else {
|
||||
// no backend change, just check enablement
|
||||
for (hash_key, backend_enabled) in enablement.iter() {
|
||||
|
@ -167,7 +199,6 @@ impl Backends {
|
|||
backend_health.enable(*backend_enabled);
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -206,12 +237,15 @@ impl Backends {
|
|||
|
||||
/// Call the service discovery method to update the collection of backends.
|
||||
///
|
||||
/// Return `true` when the new collection is different from the current set of backends.
|
||||
/// This return value is useful to tell the caller when to rebuild things that are expensive to
|
||||
/// update, such as consistent hashing rings.
|
||||
pub async fn update(&self) -> Result<bool> {
|
||||
/// The callback will be invoked when the new set of backend is different
|
||||
/// from the current one so that the caller can update the selector accordingly.
|
||||
pub async fn update<F>(&self, callback: F) -> Result<()>
|
||||
where
|
||||
F: Fn(Arc<BTreeSet<Backend>>),
|
||||
{
|
||||
let (new_backends, enablement) = self.discovery.discover().await?;
|
||||
Ok(self.do_update(new_backends, enablement))
|
||||
self.do_update(new_backends, enablement, callback);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run health check on all backends if it is set.
|
||||
|
@ -327,11 +361,9 @@ where
|
|||
/// This function will be called every `update_frequency` if this [LoadBalancer] instance
|
||||
/// is running as a background service.
|
||||
pub async fn update(&self) -> Result<()> {
|
||||
if self.backends.update().await? {
|
||||
self.selector
|
||||
.store(Arc::new(S::build(&self.backends.get_backend())))
|
||||
}
|
||||
Ok(())
|
||||
self.backends
|
||||
.update(|backends| self.selector.store(Arc::new(S::build(&backends))))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Return the first healthy [Backend] according to the selection algorithm and the
|
||||
|
@ -385,6 +417,8 @@ where
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
|
||||
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
|
||||
|
@ -415,10 +449,20 @@ mod test {
|
|||
backends.set_health_check(check);
|
||||
|
||||
// true: new backend discovered
|
||||
assert!(backends.update().await.unwrap());
|
||||
let updated = AtomicBool::new(false);
|
||||
backends
|
||||
.update(|_| updated.store(true, Relaxed))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(updated.load(Relaxed));
|
||||
|
||||
// false: no new backend discovered
|
||||
assert!(!backends.update().await.unwrap());
|
||||
let updated = AtomicBool::new(false);
|
||||
backends
|
||||
.update(|_| updated.store(true, Relaxed))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!updated.load(Relaxed));
|
||||
|
||||
backends.run_health_check(false).await;
|
||||
|
||||
|
@ -431,6 +475,31 @@ mod test {
|
|||
assert!(backends.ready(&good2));
|
||||
assert!(!backends.ready(&bad));
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn test_backends_with_ext() {
|
||||
let discovery = discovery::Static::default();
|
||||
let mut b1 = Backend::new("1.1.1.1:80").unwrap();
|
||||
b1.ext.insert(true);
|
||||
let mut b2 = Backend::new("1.0.0.1:80").unwrap();
|
||||
b2.ext.insert(1u8);
|
||||
discovery.add(b1.clone());
|
||||
discovery.add(b2.clone());
|
||||
|
||||
let backends = Backends::new(Box::new(discovery));
|
||||
|
||||
// fill in the backends
|
||||
backends.update(|_| {}).await.unwrap();
|
||||
|
||||
let backend = backends.get_backend();
|
||||
assert!(backend.contains(&b1));
|
||||
assert!(backend.contains(&b2));
|
||||
|
||||
let b2 = backend.first().unwrap();
|
||||
assert_eq!(b2.ext.get::<u8>(), Some(&1));
|
||||
|
||||
let b1 = backend.last().unwrap();
|
||||
assert_eq!(b1.ext.get::<bool>(), Some(&true));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_discovery_readiness() {
|
||||
|
@ -456,7 +525,14 @@ mod test {
|
|||
let discovery = TestDiscovery(discovery);
|
||||
|
||||
let backends = Backends::new(Box::new(discovery));
|
||||
assert!(backends.update().await.unwrap());
|
||||
|
||||
// true: new backend discovered
|
||||
let updated = AtomicBool::new(false);
|
||||
backends
|
||||
.update(|_| updated.store(true, Relaxed))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(updated.load(Relaxed));
|
||||
|
||||
let backend = backends.get_backend();
|
||||
assert!(backend.contains(&good1));
|
||||
|
@ -483,7 +559,12 @@ mod test {
|
|||
backends.set_health_check(check);
|
||||
|
||||
// true: new backend discovered
|
||||
assert!(backends.update().await.unwrap());
|
||||
let updated = AtomicBool::new(false);
|
||||
backends
|
||||
.update(|_| updated.store(true, Relaxed))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(updated.load(Relaxed));
|
||||
|
||||
backends.run_health_check(true).await;
|
||||
|
||||
|
@ -491,4 +572,46 @@ mod test {
|
|||
assert!(backends.ready(&good2));
|
||||
assert!(!backends.ready(&bad));
|
||||
}
|
||||
|
||||
mod thread_safety {
|
||||
use super::*;
|
||||
|
||||
struct MockDiscovery {
|
||||
expected: usize,
|
||||
}
|
||||
#[async_trait]
|
||||
impl ServiceDiscovery for MockDiscovery {
|
||||
async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
|
||||
let mut d = BTreeSet::new();
|
||||
let mut m = HashMap::with_capacity(self.expected);
|
||||
for i in 0..self.expected {
|
||||
let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
|
||||
m.insert(i as u64, true);
|
||||
d.insert(b);
|
||||
}
|
||||
Ok((d, m))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_consistency() {
|
||||
let expected = 3000;
|
||||
let discovery = MockDiscovery { expected };
|
||||
let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
|
||||
Backends::new(Box::new(discovery)),
|
||||
));
|
||||
let lb2 = lb.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
assert!(lb2.update().await.is_ok());
|
||||
});
|
||||
let mut backend_count = 0;
|
||||
while backend_count == 0 {
|
||||
let backends = lb.backends();
|
||||
backend_count = backends.backends.load_full().len();
|
||||
}
|
||||
assert_eq!(backend_count, expected);
|
||||
assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ use openssl_sys::{
|
|||
use std::ffi::CString;
|
||||
use std::os::raw;
|
||||
|
||||
fn cvt(r: c_int) -> Result<c_int, ErrorStack> {
|
||||
fn cvt(r: c_long) -> Result<c_long, ErrorStack> {
|
||||
if r != 1 {
|
||||
Err(ErrorStack::get())
|
||||
} else {
|
||||
|
@ -42,8 +42,8 @@ extern "C" {
|
|||
namelen: size_t,
|
||||
) -> c_int;
|
||||
|
||||
pub fn SSL_use_certificate(ssl: *const SSL, cert: *mut X509) -> c_int;
|
||||
pub fn SSL_use_PrivateKey(ctx: *const SSL, key: *mut EVP_PKEY) -> c_int;
|
||||
pub fn SSL_use_certificate(ssl: *mut SSL, cert: *mut X509) -> c_int;
|
||||
pub fn SSL_use_PrivateKey(ssl: *mut SSL, key: *mut EVP_PKEY) -> c_int;
|
||||
|
||||
pub fn SSL_set_cert_cb(
|
||||
ssl: *mut SSL,
|
||||
|
@ -64,9 +64,9 @@ pub fn add_host(verify_param: &mut X509VerifyParamRef, host: &str) -> Result<(),
|
|||
unsafe {
|
||||
cvt(X509_VERIFY_PARAM_add1_host(
|
||||
verify_param.as_ptr(),
|
||||
host.as_ptr() as *const _,
|
||||
host.as_ptr() as *const c_char,
|
||||
host.len(),
|
||||
))
|
||||
) as c_long)
|
||||
.map(|_| ())
|
||||
}
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ pub fn ssl_set_verify_cert_store(
|
|||
SSL_CTRL_SET_VERIFY_CERT_STORE,
|
||||
1, // increase the ref count of X509Store so that ssl_ctx can outlive X509StoreRef
|
||||
cert_store.as_ptr() as *mut c_void,
|
||||
) as i32)?;
|
||||
))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ pub fn ssl_set_verify_cert_store(
|
|||
/// See [SSL_use_certificate](https://www.openssl.org/docs/man1.1.1/man3/SSL_use_certificate.html).
|
||||
pub fn ssl_use_certificate(ssl: &mut SslRef, cert: &X509Ref) -> Result<(), ErrorStack> {
|
||||
unsafe {
|
||||
cvt(SSL_use_certificate(ssl.as_ptr(), cert.as_ptr()))?;
|
||||
cvt(SSL_use_certificate(ssl.as_ptr() as *mut SSL, cert.as_ptr() as *mut X509) as c_long)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ where
|
|||
T: HasPrivate,
|
||||
{
|
||||
unsafe {
|
||||
cvt(SSL_use_PrivateKey(ssl.as_ptr(), key.as_ptr()))?;
|
||||
cvt(SSL_use_PrivateKey(ssl.as_ptr() as *mut SSL, key.as_ptr() as *mut EVP_PKEY) as c_long)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -123,7 +123,7 @@ pub fn ssl_add_chain_cert(ssl: &mut SslRef, cert: &X509Ref) -> Result<(), ErrorS
|
|||
SSL_CTRL_CHAIN_CERT,
|
||||
1, // increase the ref count of X509 so that ssl can outlive X509StoreRef
|
||||
cert.as_ptr() as *mut c_void,
|
||||
) as i32)?;
|
||||
))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -137,14 +137,17 @@ pub fn ssl_set_renegotiate_mode_freely(_ssl: &mut SslRef) {}
|
|||
///
|
||||
/// See [set_groups_list](https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set1_curves.html).
|
||||
pub fn ssl_set_groups_list(ssl: &mut SslRef, groups: &str) -> Result<(), ErrorStack> {
|
||||
let groups = CString::new(groups).unwrap();
|
||||
if groups.contains('\0') {
|
||||
return Err(ErrorStack::get());
|
||||
}
|
||||
let groups = CString::new(groups).map_err(|_| ErrorStack::get())?;
|
||||
unsafe {
|
||||
cvt(SSL_ctrl(
|
||||
ssl.as_ptr(),
|
||||
SSL_CTRL_SET_GROUPS_LIST,
|
||||
0,
|
||||
groups.as_ptr() as *mut c_void,
|
||||
) as i32)?;
|
||||
))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -207,3 +210,22 @@ pub fn is_suspended_for_cert(error: &openssl::ssl::Error) -> bool {
|
|||
pub unsafe fn ssl_mut(ssl: &SslRef) -> &mut SslRef {
|
||||
SslRef::from_ptr_mut(ssl.as_ptr())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use openssl::ssl::{SslContextBuilder, SslMethod};
|
||||
|
||||
#[test]
|
||||
fn test_ssl_set_groups_list() {
|
||||
let ctx_builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
|
||||
let ssl = Ssl::new(&ctx_builder.build()).unwrap();
|
||||
let ssl_ref = unsafe { ssl_mut(&ssl) };
|
||||
|
||||
// Valid input
|
||||
assert!(ssl_set_groups_list(ssl_ref, "P-256:P-384").is_ok());
|
||||
|
||||
// Invalid input (contains null byte)
|
||||
assert!(ssl_set_groups_list(ssl_ref, "P-256\0P-384").is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,3 +55,10 @@ serde_yaml = "0.8"
|
|||
default = ["openssl"]
|
||||
openssl = ["pingora-core/openssl", "pingora-cache/openssl"]
|
||||
boringssl = ["pingora-core/boringssl", "pingora-cache/boringssl"]
|
||||
|
||||
# or locally cargo doc --config "build.rustdocflags='--cfg doc_async_trait'"
|
||||
[package.metadata.docs.rs]
|
||||
rustdoc-args = ["--cfg", "doc_async_trait"]
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(doc_async_trait)'] }
|
|
@ -35,9 +35,6 @@
|
|||
//!
|
||||
//! See `examples/load_balancer.rs` for a detailed example.
|
||||
|
||||
// enable nightly feature async trait so that the docs are cleaner
|
||||
#![cfg_attr(doc_async_trait, feature(async_fn_in_trait))]
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use futures::future::FutureExt;
|
||||
|
|
|
@ -24,7 +24,6 @@ tokio = { workspace = true, features = [
|
|||
"sync",
|
||||
] }
|
||||
pin-project-lite = "0.2"
|
||||
futures = "0.3"
|
||||
once_cell = { workspace = true }
|
||||
parking_lot = "0.12"
|
||||
thread_local = "1.0"
|
||||
|
|
|
@ -50,7 +50,7 @@ fn check_clock_thread(tm: &Arc<TimerManager>) {
|
|||
pub struct FastTimeout(Duration);
|
||||
|
||||
impl ToTimeout for FastTimeout {
|
||||
fn timeout(&self) -> BoxFuture<'static, ()> {
|
||||
fn timeout(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
|
||||
Box::pin(TIMER_MANAGER.register_timer(self.0).poll())
|
||||
}
|
||||
|
||||
|
|
|
@ -39,7 +39,6 @@ pub mod timer;
|
|||
pub use fast_timeout::fast_sleep as sleep;
|
||||
pub use fast_timeout::fast_timeout as timeout;
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
|
@ -50,7 +49,7 @@ use tokio::time::{sleep as tokio_sleep, Duration};
|
|||
///
|
||||
/// Users don't need to interact with this trait
|
||||
pub trait ToTimeout {
|
||||
fn timeout(&self) -> BoxFuture<'static, ()>;
|
||||
fn timeout(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>>;
|
||||
fn create(d: Duration) -> Self;
|
||||
}
|
||||
|
||||
|
@ -60,7 +59,7 @@ pub trait ToTimeout {
|
|||
pub struct TokioTimeout(Duration);
|
||||
|
||||
impl ToTimeout for TokioTimeout {
|
||||
fn timeout(&self) -> BoxFuture<'static, ()> {
|
||||
fn timeout(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
|
||||
Box::pin(tokio_sleep(self.0))
|
||||
}
|
||||
|
||||
|
@ -100,7 +99,7 @@ pin_project! {
|
|||
#[pin]
|
||||
value: T,
|
||||
#[pin]
|
||||
delay: Option<BoxFuture<'static, ()>>,
|
||||
delay: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
|
||||
callback: F, // callback to create the timer
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,3 +63,4 @@ boringssl = [
|
|||
proxy = ["pingora-proxy"]
|
||||
lb = ["pingora-load-balancing", "proxy"]
|
||||
cache = ["pingora-cache"]
|
||||
time = []
|
||||
|
|
Loading…
Reference in a new issue