Merged main branch

This commit is contained in:
Wladimir Palant 2024-08-10 13:06:57 +02:00
commit d6e94c93f1
80 changed files with 1895 additions and 332 deletions

2
.bleep
View file

@ -1 +1 @@
f70d8b77a4085cbe11b9559317f6d6e7e49914db
f123f5e43e9ada31a0e541b917ea674527fd06a3

View file

@ -6,8 +6,7 @@ jobs:
pingora:
strategy:
matrix:
# TODO: add nightly
toolchain: [1.78, 1.72]
toolchain: [nightly, 1.72, 1.80.0]
runs-on: ubuntu-latest
# Only run on "pull_request" event for external PRs. This is to avoid
# duplicate builds for PRs created from internal branches.
@ -46,7 +45,9 @@ jobs:
run: cargo test --verbose --doc
- name: Run cargo clippy
run: cargo clippy --all-targets --all -- --deny=warnings
run: |
[[ ${{ matrix.toolchain }} == nightly ]] || cargo clippy --all-targets --all -- --allow=unknown-lints --deny=warnings
- name: Run cargo audit
uses: actions-rust-lang/audit@v1
run: |
[[ ${{ matrix.toolchain }} == nightly ]] || (cargo install cargo-audit && cargo audit)

View file

@ -2,6 +2,36 @@
All notable changes to this project will be documented in this file.
## [0.3.0](https://github.com/cloudflare/pingora/compare/0.2.0...0.3.0) - 2024-07-12
### 🚀 Features
- Add support for HTTP modules. This feature allows users to import modules written by 3rd parties.
- Add `request_body_filter`. Now request body can be inspected and modified.
- Add H2c support.
- Add TCP fast open support.
- Add support for server side TCP keep-alive.
- Add support to get TCP_INFO.
- Add support to set DSCP.
- Add `or_err()`/`or_err_with()` API to convert `Option` to `pingora::Error`.
- Add `or_fail()` API to convert `impl std::error::Error` to `pingora::Error`.
- Add the API to track socket read and write pending time.
- Compression: allow setting level per algorithm.
### 🐛 Bug Fixes
- Fixed a panic when using multiple H2 streams in the same H2 connection to upstreams.
- Pingora now respects the `Connection` header it sends to upstream.
- Accept-Ranges header is now removed when response is compressed.
- Fix ipv6_only socket flag.
- A new H2 connection is opened now if the existing connection returns GOAWAY with graceful shutdown error.
- Fix a FD mismatch error when 0.0.0.0 is used as the upstream IP
### ⚙️ Changes and Miscellaneous Tasks
- Dependency: replace `structopt` with `clap`
- Rework the API of HTTP modules
- Optimize remove_header() API call
- UDS parsing now requires the path to have `unix:` prefix. The support for the path without prefix is deprecated and will be removed in the next release.
- Other minor API changes
## [0.2.0](https://github.com/cloudflare/pingora/compare/0.1.1...0.2.0) - 2024-05-10
### 🚀 Features

View file

@ -237,7 +237,7 @@ take advantage of with single-line change.
```rust
fn main() {
let mut my_server = Server::new(Some(Opt::default())).unwrap();
let mut my_server = Server::new(Some(Opt::parse_args())).unwrap();
...
}
```

View file

@ -20,6 +20,7 @@ In this guide, we will cover the most used features, operations and settings of
* [Examples: take control of the request](modify_filter.md)
* [Connection pooling and reuse](pooling.md)
* [Handling failures and failover](failover.md)
* [RateLimiter quickstart](rate_limiter.md)
## Advanced topics (WIP)
* [Pingora internals](internals.md)

View file

@ -0,0 +1,167 @@
# **RateLimiter quickstart**
Pingora provides the `pingora-limits` crate, which offers a simple and easy-to-use rate limiter for your application. Below is an example of how you can use [`Rate`](https://docs.rs/pingora-limits/latest/pingora_limits/rate/struct.Rate.html) to create an application that uses multiple limiters to restrict the rate at which requests can be made on a per-app basis (determined by a request header).
## Steps
1. Add the following dependencies to your `Cargo.toml`:
```toml
async-trait="0.1"
pingora = { version = "0.3", features = [ "lb" ] }
pingora-limits = "0.3.0"
once_cell = "1.19.0"
```
2. Declare a global rate limiter that tracks the request rate of each client. In this example, clients are identified by `appid`.
3. Override the `request_filter` method in the `ProxyHttp` trait to implement rate limiting.
1. Retrieve the client appid from header.
2. Record the request with the global rate limiter, keyed by the client appid, and retrieve the number of requests observed in the current window.
3. If the current window requests exceed the limit, return 429 and set RateLimiter associated headers.
4. If the request is not rate limited, return `Ok(false)` to continue the request.
## Example
```rust
use async_trait::async_trait;
use once_cell::sync::Lazy;
use pingora::http::ResponseHeader;
use pingora::prelude::*;
use pingora_limits::rate::Rate;
use std::sync::Arc;
use std::time::Duration;
// Entry point of the example: builds a server that proxies to two upstreams
// through the `LB` proxy service listening on 0.0.0.0:6188.
fn main() {
let mut server = Server::new(Some(Opt::default())).unwrap();
server.bootstrap();
let mut upstreams = LoadBalancer::try_from_iter(["1.1.1.1:443", "1.0.0.1:443"]).unwrap();
// Set health check
let hc = TcpHealthCheck::new();
upstreams.set_health_check(hc);
upstreams.health_check_frequency = Some(Duration::from_secs(1));
// Set background service
let background = background_service("health check", upstreams);
let upstreams = background.task();
// Set load balancer
let mut lb = http_proxy_service(&server.configuration, LB(upstreams));
lb.add_tcp("0.0.0.0:6188");
// Register the health check background service and the proxy service, then run
server.add_service(background);
server.add_service(lb);
server.run_forever();
}
pub struct LB(Arc<LoadBalancer<RoundRobin>>);
impl LB {
    /// Returns the value of the request's `appid` header as an owned string.
    ///
    /// Yields `None` when the header is absent or when `to_str()` rejects its value.
    pub fn get_request_appid(&self, session: &mut Session) -> Option<String> {
        session
            .req_header()
            .headers
            .get("appid")
            // A header value that fails `to_str()` is treated the same as a missing header.
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned)
    }
}
// Rate limiter
static RATE_LIMITER: Lazy<Rate> = Lazy::new(|| Rate::new(Duration::from_secs(1)));
// max request per second per client
static MAX_REQ_PER_SEC: isize = 1;
#[async_trait]
impl ProxyHttp for LB {
    type CTX = ();

    fn new_ctx(&self) {}

    /// Selects the next upstream from the load balancer and builds the peer for it.
    async fn upstream_peer(
        &self,
        _session: &mut Session,
        _ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        let backend = self.0.select(b"", 256).unwrap();
        // Set SNI
        let sni = "one.one.one.one".to_string();
        Ok(Box::new(HttpPeer::new(backend, true, sni)))
    }

    /// Rewrites the `Host` header of the request sent to the upstream.
    async fn upstream_request_filter(
        &self,
        _session: &mut Session,
        upstream_request: &mut RequestHeader,
        _ctx: &mut Self::CTX,
    ) -> Result<()>
    where
        Self::CTX: Send + Sync,
    {
        upstream_request
            .insert_header("Host", "one.one.one.one")
            .unwrap();
        Ok(())
    }

    /// Applies the per-appid rate limit.
    ///
    /// Returns `Ok(true)` after writing a 429 response when the limit is exceeded,
    /// and `Ok(false)` to let the request continue otherwise.
    async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result<bool>
    where
        Self::CTX: Send + Sync,
    {
        // no client appid found, skip rate limiting
        let Some(appid) = self.get_request_appid(session) else {
            return Ok(false);
        };

        // retrieve the current window requests
        let curr_window_requests = RATE_LIMITER.observe(&appid, 1);
        if curr_window_requests <= MAX_REQ_PER_SEC {
            return Ok(false);
        }

        // rate limited, return 429
        let mut header = ResponseHeader::build(429, None).unwrap();
        header
            .insert_header("X-Rate-Limit-Limit", MAX_REQ_PER_SEC.to_string())
            .unwrap();
        header.insert_header("X-Rate-Limit-Remaining", "0").unwrap();
        header.insert_header("X-Rate-Limit-Reset", "1").unwrap();
        session.set_keepalive(None);
        session
            .write_response_header(Box::new(header), true)
            .await?;
        Ok(true)
    }
}
```
## Testing
To use the example above,
1. Run your program with `cargo run`.
2. Verify the program is working with a few executions of `curl localhost:6188 -H "appid:1" -v`
- The first request should work and any later requests that arrive within 1s of a previous request should fail with:
```
* Trying 127.0.0.1:6188...
* Connected to localhost (127.0.0.1) port 6188 (#0)
> GET / HTTP/1.1
> Host: localhost:6188
> User-Agent: curl/7.88.1
> Accept: */*
> appid:1
>
< HTTP/1.1 429 Too Many Requests
< X-Rate-Limit-Limit: 1
< X-Rate-Limit-Remaining: 0
< X-Rate-Limit-Reset: 1
< Date: Sun, 14 Jul 2024 20:29:02 GMT
< Connection: close
<
* Closing connection 0
```
## Complete Example
You can run the pre-made example code in the [`pingora-proxy` examples folder](https://github.com/cloudflare/pingora/tree/main/pingora-proxy/examples/rate_limiter.rs) with
```
cargo run --example rate_limiter
```

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-boringssl"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-cache"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
@ -17,12 +17,12 @@ name = "pingora_cache"
path = "src/lib.rs"
[dependencies]
pingora-core = { version = "0.2.0", path = "../pingora-core", default-features = false }
pingora-error = { version = "0.2.0", path = "../pingora-error" }
pingora-header-serde = { version = "0.2.0", path = "../pingora-header-serde" }
pingora-http = { version = "0.2.0", path = "../pingora-http" }
pingora-lru = { version = "0.2.0", path = "../pingora-lru" }
pingora-timeout = { version = "0.2.0", path = "../pingora-timeout" }
pingora-core = { version = "0.3.0", path = "../pingora-core", default-features = false }
pingora-error = { version = "0.3.0", path = "../pingora-error" }
pingora-header-serde = { version = "0.3.0", path = "../pingora-header-serde" }
pingora-http = { version = "0.3.0", path = "../pingora-http" }
pingora-lru = { version = "0.3.0", path = "../pingora-lru" }
pingora-timeout = { version = "0.3.0", path = "../pingora-timeout" }
http = { workspace = true }
indexmap = "1"
once_cell = { workspace = true }

View file

@ -32,6 +32,7 @@ use std::time::SystemTime;
///
/// - Space optimized in-memory LRU (see [pingora_lru]).
/// - Instead of a single giant LRU, this struct shards the assets into `N` independent LRUs.
///
/// This allows [EvictionManager::save()] not to lock the entire cache manager while performing
/// serialization.
pub struct Manager<const N: usize>(Lru<CompactCacheKey, N>);

View file

@ -130,7 +130,7 @@ impl CacheKey {
/// Storage optimized cache key to keep in memory or in storage
// 16 bytes + 8 bytes (+16 * u8) + user_tag.len() + 16 Bytes (Box<str>)
#[derive(Debug, Deserialize, Serialize, Clone, Hash, PartialEq, Eq)]
#[derive(Debug, Deserialize, Serialize, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct CompactCacheKey {
pub primary: HashBinary,
// save 8 bytes for non-variance but waste 8 bytes for variance vs, store flat 16 bytes

View file

@ -44,7 +44,7 @@ pub use key::CacheKey;
use lock::{CacheLock, LockStatus, Locked};
pub use memory::MemCache;
pub use meta::{CacheMeta, CacheMetaDefaults};
pub use storage::{HitHandler, MissHandler, Storage};
pub use storage::{HitHandler, MissHandler, PurgeType, Storage};
pub use variance::VarianceBuilder;
pub mod prelude {}
@ -77,6 +77,8 @@ pub enum CachePhase {
Miss,
/// A staled (expired) asset is found
Stale,
/// A staled (expired) asset was found, but another request is revalidating it
StaleUpdating,
/// A staled (expired) asset was found, so a fresh one was fetched
Expired,
/// A staled (expired) asset was found, and it was revalidated to be fresh
@ -96,6 +98,7 @@ impl CachePhase {
CachePhase::Hit => "hit",
CachePhase::Miss => "miss",
CachePhase::Stale => "stale",
CachePhase::StaleUpdating => "stale-updating",
CachePhase::Expired => "expired",
CachePhase::Revalidated => "revalidated",
CachePhase::RevalidatedNoCache(_) => "revalidated-nocache",
@ -260,7 +263,7 @@ impl HttpCache {
use CachePhase::*;
match self.phase {
Disabled(_) | Bypass | Miss | Expired | Revalidated | RevalidatedNoCache(_) => true,
Hit | Stale => false,
Hit | Stale | StaleUpdating => false,
Uninit | CacheKey => false, // invalid states for this call, treat them as false to keep it simple
}
}
@ -344,10 +347,10 @@ impl HttpCache {
/// - `storage`: the cache storage backend that implements [storage::Storage]
/// - `eviction`: optionally the eviction manager, without it, nothing will be evicted from the storage
/// - `predictor`: optionally a cache predictor. The cache predictor predicts whether something is likely
/// to be cacheable or not. This is useful because the proxy can apply different types of optimization to
/// cacheable and uncacheable requests.
/// to be cacheable or not. This is useful because the proxy can apply different types of optimization to
/// cacheable and uncacheable requests.
/// - `cache_lock`: optionally a cache lock which handles concurrent lookups to the same asset. Without it
/// such lookups will all be allowed to fetch the asset independently.
/// such lookups will all be allowed to fetch the asset independently.
pub fn enable(
&mut self,
storage: &'static (dyn storage::Storage + Sync),
@ -493,7 +496,8 @@ impl HttpCache {
match self.phase {
// from CacheKey: set state to miss during cache lookup
// from Bypass: response became cacheable, set state to miss to cache
CachePhase::CacheKey | CachePhase::Bypass => {
// from Stale: waited for cache lock, then retried and found asset was gone
CachePhase::CacheKey | CachePhase::Bypass | CachePhase::Stale => {
self.phase = CachePhase::Miss;
self.inner_mut().traces.start_miss_span();
}
@ -508,6 +512,7 @@ impl HttpCache {
match self.phase {
CachePhase::Hit
| CachePhase::Stale
| CachePhase::StaleUpdating
| CachePhase::Revalidated
| CachePhase::RevalidatedNoCache(_) => self.inner_mut().body_reader.as_mut().unwrap(),
_ => panic!("wrong phase {:?}", self.phase),
@ -543,6 +548,7 @@ impl HttpCache {
| CachePhase::Miss
| CachePhase::Expired
| CachePhase::Stale
| CachePhase::StaleUpdating
| CachePhase::Revalidated
| CachePhase::RevalidatedNoCache(_) => {
let inner = self.inner_mut();
@ -658,7 +664,10 @@ impl HttpCache {
let handle = span.handle();
for item in evicted {
// TODO: warn/log the error
let _ = inner.storage.purge(&item, &handle).await;
let _ = inner
.storage
.purge(&item, PurgeType::Eviction, &handle)
.await;
}
}
inner.traces.finish_miss_span();
@ -782,6 +791,14 @@ impl HttpCache {
// TODO: remove this asset from cache once finished?
}
/// Mark this asset as stale, but being updated separately from this request.
///
/// # Panics
/// Panics if the current phase is anything other than [CachePhase::Stale].
pub fn set_stale_updating(&mut self) {
    if !matches!(self.phase, CachePhase::Stale) {
        panic!("wrong phase {:?}", self.phase);
    }
    self.phase = CachePhase::StaleUpdating;
}
/// Update the variance of the [CacheMeta].
///
/// Note that this process may change the lookup `key`, and eventually (when the asset is
@ -850,6 +867,7 @@ impl HttpCache {
match self.phase {
// TODO: allow in Bypass phase?
CachePhase::Stale
| CachePhase::StaleUpdating
| CachePhase::Expired
| CachePhase::Hit
| CachePhase::Revalidated
@ -878,6 +896,7 @@ impl HttpCache {
match self.phase {
CachePhase::Miss
| CachePhase::Stale
| CachePhase::StaleUpdating
| CachePhase::Expired
| CachePhase::Hit
| CachePhase::Revalidated
@ -1002,7 +1021,7 @@ impl HttpCache {
/// Whether this request's cache hit is staled
fn has_staled_asset(&self) -> bool {
self.phase == CachePhase::Stale
matches!(self.phase, CachePhase::Stale | CachePhase::StaleUpdating)
}
/// Whether this asset is staled and stale if error is allowed
@ -1063,7 +1082,10 @@ impl HttpCache {
let inner = self.inner_mut();
let mut span = inner.traces.child("purge");
let key = inner.key.as_ref().unwrap().to_compact();
let result = inner.storage.purge(&key, &span.handle()).await;
let result = inner
.storage
.purge(&key, PurgeType::Invalidation, &span.handle())
.await;
// FIXME: also need to remove from eviction manager
span.set_tag(|| trace::Tag::new("purged", matches!(result, Ok(true))));
result

View file

@ -306,7 +306,12 @@ impl Storage for MemCache {
Ok(Box::new(miss_handler))
}
async fn purge(&'static self, key: &CompactCacheKey, _trace: &SpanHandle) -> Result<bool> {
async fn purge(
&'static self,
key: &CompactCacheKey,
_type: PurgeType,
_trace: &SpanHandle,
) -> Result<bool> {
// This usually purges the primary key because, without a lookup, the variance key is usually
// empty
let hash = key.combined();
@ -525,7 +530,9 @@ mod test {
assert!(cache.temp.read().contains_key(&hash));
let result = cache.purge(&key, &Span::inactive().handle()).await;
let result = cache
.purge(&key, PurgeType::Invalidation, &Span::inactive().handle())
.await;
assert!(result.is_ok());
assert!(!cache.temp.read().contains_key(&hash));
@ -551,7 +558,9 @@ mod test {
assert!(cache.cached.read().contains_key(&hash));
let result = cache.purge(&key, &Span::inactive().handle()).await;
let result = cache
.purge(&key, PurgeType::Invalidation, &Span::inactive().handle())
.await;
assert!(result.is_ok());
assert!(!cache.cached.read().contains_key(&hash));

View file

@ -119,7 +119,7 @@ impl<C: CachePut> CachePutCtx<C> {
.handle();
for item in evicted {
// TODO: warn/log the error
let _ = self.storage.purge(&item, &trace).await;
let _ = self.storage.purge(&item, PurgeType::Eviction, &trace).await;
}
}

View file

@ -22,6 +22,15 @@ use async_trait::async_trait;
use pingora_error::Result;
use std::any::Any;
/// The reason a purge() is called
#[derive(Debug, Clone, Copy)]
pub enum PurgeType {
/// For eviction because the cache storage is full
Eviction,
/// For cache invalidation
Invalidation,
}
/// Cache storage interface
#[async_trait]
pub trait Storage {
@ -45,7 +54,12 @@ pub trait Storage {
/// Delete the cached asset for the given key
///
/// [CompactCacheKey] is used here because it is how eviction managers store the keys
async fn purge(&'static self, key: &CompactCacheKey, trace: &SpanHandle) -> Result<bool>;
async fn purge(
&'static self,
key: &CompactCacheKey,
purge_type: PurgeType,
trace: &SpanHandle,
) -> Result<bool>;
/// Update cache header and metadata for the already stored asset.
async fn update_meta(

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-core"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
@ -19,13 +19,13 @@ name = "pingora_core"
path = "src/lib.rs"
[dependencies]
pingora-runtime = { version = "0.2.0", path = "../pingora-runtime" }
pingora-openssl = { version = "0.2.0", path = "../pingora-openssl", optional = true }
pingora-boringssl = { version = "0.2.0", path = "../pingora-boringssl", optional = true }
pingora-pool = { version = "0.2.0", path = "../pingora-pool" }
pingora-error = { version = "0.2.0", path = "../pingora-error" }
pingora-timeout = { version = "0.2.0", path = "../pingora-timeout" }
pingora-http = { version = "0.2.0", path = "../pingora-http" }
pingora-runtime = { version = "0.3.0", path = "../pingora-runtime" }
pingora-openssl = { version = "0.3.0", path = "../pingora-openssl", optional = true }
pingora-boringssl = { version = "0.3.0", path = "../pingora-boringssl", optional = true }
pingora-pool = { version = "0.3.0", path = "../pingora-pool" }
pingora-error = { version = "0.3.0", path = "../pingora-error" }
pingora-timeout = { version = "0.3.0", path = "../pingora-timeout" }
pingora-http = { version = "0.3.0", path = "../pingora-http" }
tokio = { workspace = true, features = ["rt-multi-thread", "signal"] }
futures = "0.3"
async-trait = { workspace = true }
@ -81,4 +81,4 @@ jemallocator = "0.5"
default = ["openssl"]
openssl = ["pingora-openssl"]
boringssl = ["pingora-boringssl"]
patched_http1 = []
patched_http1 = []

View file

@ -28,7 +28,7 @@ use crate::protocols::Stream;
use crate::server::ShutdownWatch;
/// This trait defines how to map a request to a response
#[cfg_attr(not(doc_async_trait), async_trait)]
#[async_trait]
pub trait ServeHttp {
/// Define the mapping from a request to a response.
/// Note that the request header is already read, but the implementation needs to read the
@ -42,7 +42,7 @@ pub trait ServeHttp {
}
// TODO: remove this in favor of HttpServer?
#[cfg_attr(not(doc_async_trait), async_trait)]
#[async_trait]
impl<SV> HttpServerApp for SV
where
SV: ServeHttp + Send + Sync,
@ -128,7 +128,7 @@ impl<SV> HttpServer<SV> {
}
}
#[cfg_attr(not(doc_async_trait), async_trait)]
#[async_trait]
impl<SV> HttpServerApp for HttpServer<SV>
where
SV: ServeHttp + Send + Sync,

View file

@ -28,7 +28,7 @@ use crate::protocols::Digest;
use crate::protocols::Stream;
use crate::protocols::ALPN;
#[cfg_attr(not(doc_async_trait), async_trait)]
#[async_trait]
/// This trait defines the interface of a transport layer (TCP or TLS) application.
pub trait ServerApp {
/// Whenever a new connection is established, this function will be called with the established
@ -62,7 +62,7 @@ pub struct HttpServerOptions {
}
/// This trait defines the interface of an HTTP application.
#[cfg_attr(not(doc_async_trait), async_trait)]
#[async_trait]
pub trait HttpServerApp {
/// Similar to the [`ServerApp`], this function is called whenever a new HTTP session is established.
///
@ -95,7 +95,7 @@ pub trait HttpServerApp {
async fn http_cleanup(&self) {}
}
#[cfg_attr(not(doc_async_trait), async_trait)]
#[async_trait]
impl<T> ServerApp for T
where
T: HttpServerApp + Send + Sync + 'static,

View file

@ -29,7 +29,7 @@ use crate::protocols::http::ServerSession;
/// collected via the [Prometheus](https://docs.rs/prometheus/) crate;
pub struct PrometheusHttpApp;
#[cfg_attr(not(doc_async_trait), async_trait)]
#[async_trait]
impl ServeHttp for PrometheusHttpApp {
async fn response(&self, _http_session: &mut ServerSession) -> Response<Vec<u8>> {
let encoder = TextEncoder::new();

View file

@ -51,10 +51,11 @@ impl Connector {
pub async fn release_http_session<P: Peer + Send + Sync + 'static>(
&self,
session: HttpSession,
mut session: HttpSession,
peer: &P,
idle_timeout: Option<Duration>,
) {
session.respect_keepalive();
if let Some(stream) = session.reuse().await {
self.transport
.release_stream(stream, peer.reuse_hash(), idle_timeout);

View file

@ -21,7 +21,7 @@ use crate::upstreams::peer::{Peer, ALPN};
use bytes::Bytes;
use h2::client::SendRequest;
use log::{debug, warn};
use log::debug;
use parking_lot::{Mutex, RwLock};
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use pingora_pool::{ConnectionMeta, ConnectionPool, PoolNode};
@ -52,6 +52,8 @@ pub(crate) struct ConnectionRefInner {
max_streams: usize,
// how many concurrent streams already active
current_streams: AtomicUsize,
// The connection is gracefully shutting down, no more stream is allowed
shutting_down: AtomicBool,
// because `SendRequest` doesn't actually have access to the underlying Stream,
// we log info about timing and tcp info here.
pub(crate) digest: Digest,
@ -78,12 +80,14 @@ impl ConnectionRef {
id,
max_streams,
current_streams: AtomicUsize::new(0),
shutting_down: false.into(),
digest,
release_lock: Arc::new(Mutex::new(())),
}))
}
pub fn more_streams_allowed(&self) -> bool {
self.0.max_streams > self.0.current_streams.load(Ordering::Relaxed)
!self.is_shutting_down()
&& self.0.max_streams > self.0.current_streams.load(Ordering::Relaxed)
}
pub fn is_idle(&self) -> bool {
@ -102,6 +106,10 @@ impl ConnectionRef {
&self.0.digest
}
/// Mutable access to the connection [Digest].
///
/// Returns `None` when `Arc::get_mut` cannot hand out exclusive access to the
/// shared connection state.
pub fn digest_mut(&mut self) -> Option<&mut Digest> {
    let inner = Arc::get_mut(&mut self.0)?;
    Some(&mut inner.digest)
}
pub fn ping_timedout(&self) -> bool {
self.0.ping_timeout_occurred.load(Ordering::Relaxed)
}
@ -110,6 +118,12 @@ impl ConnectionRef {
*self.0.closed.borrow()
}
/// Whether the connection has been flagged as shutting down.
///
/// Different from is_closed, existing streams can still be processed but can no longer create
/// new stream.
pub fn is_shutting_down(&self) -> bool {
self.0.shutting_down.load(Ordering::Relaxed)
}
// spawn a stream if more stream is allowed, otherwise return Ok(None)
pub async fn spawn_stream(&self) -> Result<Option<Http2Session>> {
// Atomically check if the current_stream is over the limit
@ -120,13 +134,28 @@ impl ConnectionRef {
self.0.current_streams.fetch_sub(1, Ordering::SeqCst);
return Ok(None);
}
let send_req = self.0.connection_stub.new_stream().await.map_err(|e| {
// fail to create the stream, reset the counter
self.0.current_streams.fetch_sub(1, Ordering::SeqCst);
e
})?;
Ok(Some(Http2Session::new(send_req, self.clone())))
match self.0.connection_stub.new_stream().await {
Ok(send_req) => Ok(Some(Http2Session::new(send_req, self.clone()))),
Err(e) => {
// fail to create the stream, reset the counter
self.0.current_streams.fetch_sub(1, Ordering::SeqCst);
// Remote sends GOAWAY(NO_ERROR): graceful shutdown: this connection no longer
// accepts new streams. We can still try to create new connection.
if e.root_cause()
.downcast_ref::<h2::Error>()
.map(|e| {
e.is_go_away() && e.is_remote() && e.reason() == Some(h2::Reason::NO_ERROR)
})
.unwrap_or(false)
{
self.0.shutting_down.store(true, Ordering::Relaxed);
Ok(None)
} else {
Err(e)
}
}
}
}
}
@ -273,11 +302,6 @@ impl Connector {
.or_else(|| self.idle_pool.get(&reuse_hash));
if let Some(conn) = maybe_conn {
let h2_stream = conn.spawn_stream().await?;
if h2_stream.is_none() {
warn!("connection from the pools should have free stream to allocate, current in use {}, max {}",
conn.0.current_streams.load(Ordering::Relaxed),
conn.0.max_streams);
}
if conn.more_streams_allowed() {
self.in_use_pool.insert(reuse_hash, conn);
}
@ -314,8 +338,8 @@ impl Connector {
// find and remove the conn stored in in_use_pool so that it could be put in the idle pool
// if necessary
let conn = self.in_use_pool.release(reuse_hash, id).unwrap_or(conn);
if conn.is_closed() {
// Already dead h2 connection
if conn.is_closed() || conn.is_shutting_down() {
// should never be put back to the pool
return;
}
if conn.is_idle() {

View file

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use log::debug;
use pingora_error::{Context, Error, ErrorType::*, OrErr, Result};
use rand::seq::SliceRandom;
@ -19,13 +20,19 @@ use std::net::SocketAddr as InetSocketAddr;
use std::os::unix::io::AsRawFd;
use crate::protocols::l4::ext::{
connect_uds, connect_with as tcp_connect, set_recv_buf, set_tcp_fastopen_connect,
connect_uds, connect_with as tcp_connect, set_dscp, set_recv_buf, set_tcp_fastopen_connect,
};
use crate::protocols::l4::socket::SocketAddr;
use crate::protocols::l4::stream::Stream;
use crate::protocols::{GetSocketDigest, SocketDigest};
use crate::upstreams::peer::Peer;
/// The interface to establish a L4 connection
#[async_trait]
pub trait Connect: std::fmt::Debug {
/// Establish the connection for `addr`, returning the resulting [Stream] or an error.
async fn connect(&self, addr: &SocketAddr) -> Result<Stream>;
}
/// Establish a connection (l4) to the given peer using its settings and an optional bind address.
pub async fn connect<P>(peer: &P, bind_to: Option<InetSocketAddr>) -> Result<Stream>
where
@ -37,68 +44,78 @@ where
.err_context(|| format!("Fail to establish CONNECT proxy: {}", peer));
}
let peer_addr = peer.address();
let mut stream: Stream = match peer_addr {
SocketAddr::Inet(addr) => {
let connect_future = tcp_connect(addr, bind_to.as_ref(), |socket| {
if peer.tcp_fast_open() {
set_tcp_fastopen_connect(socket.as_raw_fd())?;
}
if let Some(recv_buf) = peer.tcp_recv_buf() {
debug!("Setting recv buf size");
set_recv_buf(socket.as_raw_fd(), recv_buf)?;
}
Ok(())
});
let conn_res = match peer.connection_timeout() {
Some(t) => pingora_timeout::timeout(t, connect_future)
.await
.explain_err(ConnectTimedout, |_| {
format!("timeout {t:?} connecting to server {peer}")
})?,
None => connect_future.await,
};
match conn_res {
Ok(socket) => {
debug!("connected to new server: {}", peer.address());
Ok(socket.into())
}
Err(e) => {
let c = format!("Fail to connect to {peer}");
match e.etype() {
SocketError | BindError => Error::e_because(InternalError, c, e),
_ => Err(e.more_context(c)),
let mut stream: Stream =
if let Some(custom_l4) = peer.get_peer_options().and_then(|o| o.custom_l4.as_ref()) {
custom_l4.connect(peer_addr).await?
} else {
match peer_addr {
SocketAddr::Inet(addr) => {
let connect_future = tcp_connect(addr, bind_to.as_ref(), |socket| {
if peer.tcp_fast_open() {
set_tcp_fastopen_connect(socket.as_raw_fd())?;
}
if let Some(recv_buf) = peer.tcp_recv_buf() {
debug!("Setting recv buf size");
set_recv_buf(socket.as_raw_fd(), recv_buf)?;
}
if let Some(dscp) = peer.dscp() {
debug!("Setting dscp");
set_dscp(socket.as_raw_fd(), dscp)?;
}
Ok(())
});
let conn_res = match peer.connection_timeout() {
Some(t) => pingora_timeout::timeout(t, connect_future)
.await
.explain_err(ConnectTimedout, |_| {
format!("timeout {t:?} connecting to server {peer}")
})?,
None => connect_future.await,
};
match conn_res {
Ok(socket) => {
debug!("connected to new server: {}", peer.address());
Ok(socket.into())
}
Err(e) => {
let c = format!("Fail to connect to {peer}");
match e.etype() {
SocketError | BindError => Error::e_because(InternalError, c, e),
_ => Err(e.more_context(c)),
}
}
}
}
}
}
SocketAddr::Unix(addr) => {
let connect_future = connect_uds(
addr.as_pathname()
.expect("non-pathname unix sockets not supported as peer"),
);
let conn_res = match peer.connection_timeout() {
Some(t) => pingora_timeout::timeout(t, connect_future)
.await
.explain_err(ConnectTimedout, |_| {
format!("timeout {t:?} connecting to server {peer}")
})?,
None => connect_future.await,
};
match conn_res {
Ok(socket) => {
debug!("connected to new server: {}", peer.address());
Ok(socket.into())
}
Err(e) => {
let c = format!("Fail to connect to {peer}");
match e.etype() {
SocketError | BindError => Error::e_because(InternalError, c, e),
_ => Err(e.more_context(c)),
SocketAddr::Unix(addr) => {
let connect_future = connect_uds(
addr.as_pathname()
.expect("non-pathname unix sockets not supported as peer"),
);
let conn_res = match peer.connection_timeout() {
Some(t) => pingora_timeout::timeout(t, connect_future)
.await
.explain_err(ConnectTimedout, |_| {
format!("timeout {t:?} connecting to server {peer}")
})?,
None => connect_future.await,
};
match conn_res {
Ok(socket) => {
debug!("connected to new server: {}", peer.address());
Ok(socket.into())
}
Err(e) => {
let c = format!("Fail to connect to {peer}");
match e.etype() {
SocketError | BindError => Error::e_because(InternalError, c, e),
_ => Err(e.more_context(c)),
}
}
}
}
}
}
}?;
}?
};
let tracer = peer.get_tracer();
if let Some(t) = tracer {
t.0.on_connected();
@ -245,6 +262,29 @@ mod tests {
assert_eq!(new_session.unwrap_err().etype(), &ConnectTimedout)
}
// Checks that a connector set in the peer's `custom_l4` option is the one that
// establishes the connection.
#[tokio::test]
async fn test_custom_connect() {
#[derive(Debug)]
struct MyL4;
#[async_trait]
impl Connect for MyL4 {
// Ignores the requested address and always dials 1.1.1.1:80
async fn connect(&self, _addr: &SocketAddr) -> Result<Stream> {
tokio::net::TcpStream::connect("1.1.1.1:80")
.await
.map(|s| s.into())
.or_fail()
}
}
// :79 shouldn't be able to be connected to
let mut peer = BasicPeer::new("1.1.1.1:79");
peer.options.custom_l4 = Some(std::sync::Arc::new(MyL4 {}));
let new_session = connect(&peer, None).await;
// but MyL4 connects to :80 instead
assert!(new_session.is_ok());
}
#[tokio::test]
async fn test_connect_proxy_fail() {
let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());

View file

@ -25,6 +25,7 @@ use crate::tls::ssl::SslConnector;
use crate::upstreams::peer::{Peer, ALPN};
use l4::connect as l4_connect;
pub use l4::Connect as L4Connect;
use log::{debug, error, warn};
use offload::OffloadRuntime;
use parking_lot::RwLock;

View file

@ -18,8 +18,6 @@
#![allow(clippy::match_wild_err_arm)]
#![allow(clippy::missing_safety_doc)]
#![allow(clippy::upper_case_acronyms)]
// enable nightly feature async trait so that the docs are cleaner
#![cfg_attr(doc_async_trait, feature(async_fn_in_trait))]
//! # Pingora
//!

View file

@ -25,7 +25,7 @@ use std::os::unix::net::UnixListener as StdUnixListener;
use std::time::Duration;
use tokio::net::TcpSocket;
use crate::protocols::l4::ext::set_tcp_fastopen_backlog;
use crate::protocols::l4::ext::{set_dscp, set_tcp_fastopen_backlog};
use crate::protocols::l4::listener::Listener;
pub use crate::protocols::l4::stream::Stream;
use crate::protocols::TcpKeepalive;
@ -76,6 +76,9 @@ pub struct TcpSocketOptions {
/// Enable TCP keepalive on accepted connections.
/// See the [man page](https://man7.org/linux/man-pages/man7/tcp.7.html) for more information.
pub tcp_keepalive: Option<TcpKeepalive>,
/// Specifies the server should set the following DSCP value on outgoing connections.
/// See the [RFC](https://datatracker.ietf.org/doc/html/rfc2474) for more details.
pub dscp: Option<u8>,
// TODO: allow configuring reuseaddr, backlog, etc. from here?
}
@ -150,6 +153,10 @@ fn apply_tcp_socket_options(sock: &TcpSocket, opt: Option<&TcpSocketOptions>) ->
if let Some(backlog) = opt.tcp_fastopen {
set_tcp_fastopen_backlog(sock.as_raw_fd(), backlog)?;
}
if let Some(dscp) = opt.dscp {
set_dscp(sock.as_raw_fd(), dscp)?;
}
Ok(())
}
@ -280,6 +287,9 @@ impl ListenerEndpoint {
if let Some(ka) = op.tcp_keepalive.as_ref() {
stream.set_keepalive(ka)?;
}
if let Some(dscp) = op.dscp {
set_dscp(stream.as_raw_fd(), dscp)?;
}
Ok(())
}

View file

@ -38,6 +38,15 @@ pub struct TlsSettings {
callbacks: Option<TlsAcceptCallbacks>,
}
impl From<SslAcceptorBuilder> for TlsSettings {
fn from(settings: SslAcceptorBuilder) -> Self {
TlsSettings {
accept_builder: settings,
callbacks: None,
}
}
}
impl Deref for TlsSettings {
type Target = SslAcceptorBuilder;

View file

@ -65,7 +65,7 @@ pub trait HttpModule {
fn as_any_mut(&mut self) -> &mut dyn Any;
}
type Module = Box<dyn HttpModule + 'static + Send + Sync>;
pub type Module = Box<dyn HttpModule + 'static + Send + Sync>;
/// Trait to init the http module ctx for each request
pub trait HttpModuleBuilder {

View file

@ -89,7 +89,9 @@ impl HttpSession {
/// Set the write timeout for writing header and body.
///
/// The timeout is per write operation, not on the overall time writing the entire request
/// The timeout is per write operation, not on the overall time writing the entire request.
///
/// This is a noop for h2.
pub fn set_write_timeout(&mut self, timeout: Duration) {
match self {
HttpSession::H1(h1) => h1.write_timeout = Some(timeout),
@ -151,7 +153,7 @@ impl HttpSession {
/// Return the [Digest] of the connection
///
/// For reused connection, the timing in the digest will reflect its initial handshakes
/// The caller should check if the connection is reused to avoid misuse of the timing field
/// The caller should check if the connection is reused to avoid misuse of the timing field.
pub fn digest(&self) -> Option<&Digest> {
match self {
Self::H1(s) => Some(s.digest()),
@ -159,6 +161,16 @@ impl HttpSession {
}
}
/// Return a mutable reference to the connection's [Digest]; see [`Self::digest`]
/// for what the digest contains.
///
/// Will return `None` if this is an H2 session and multiple streams are open.
pub fn digest_mut(&mut self) -> Option<&mut Digest> {
    match self {
        // H2 decides on its own whether mutable access can be handed out
        Self::H2(h2) => h2.digest_mut(),
        // H1 always has a digest to hand out
        Self::H1(h1) => {
            let digest = h1.digest_mut();
            Some(digest)
        }
    }
}
/// Return the server (peer) address of the connection.
pub fn server_addr(&self) -> Option<&SocketAddr> {
match self {

View file

@ -42,7 +42,6 @@ impl Decompressor {
impl Encode for Decompressor {
fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes> {
// reserve at most 16k
const MAX_INIT_COMPRESSED_SIZE_CAP: usize = 4 * 1024;
// Brotli compress ratio can be 3.5 to 4.5
const ESTIMATED_COMPRESSION_RATIO: usize = 4;

View file

@ -12,15 +12,65 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use super::Encode;
use super::{Encode, COMPRESSION_ERROR};
use bytes::Bytes;
use flate2::write::GzEncoder;
use pingora_error::Result;
use flate2::write::{GzDecoder, GzEncoder};
use pingora_error::{OrErr, Result};
use std::io::Write;
use std::time::{Duration, Instant};
// TODO: unzip
/// Streaming Gzip decompressor that collects its output in an in-memory buffer
/// and keeps running statistics about the work done.
pub struct Decompressor {
// Gzip decoder writing the decompressed bytes into a `Vec<u8>`
decompress: GzDecoder<Vec<u8>>,
// total number of compressed bytes fed in so far
total_in: usize,
// total number of decompressed bytes produced so far
total_out: usize,
// accumulated time spent decompressing
duration: Duration,
}
impl Decompressor {
    /// Create a decompressor with an empty output buffer and zeroed statistics.
    pub fn new() -> Self {
        let sink: Vec<u8> = Vec::new();
        Self {
            decompress: GzDecoder::new(sink),
            total_in: 0,
            total_out: 0,
            duration: Duration::ZERO,
        }
    }
}
impl Encode for Decompressor {
    /// Feed one chunk of gzip-compressed `input` into the decoder and return the
    /// decompressed bytes accumulated in the output buffer.
    ///
    /// Pass `end = true` with the last chunk so the gzip stream gets finished.
    /// Fails with `COMPRESSION_ERROR` when the decoder rejects the data.
    fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes> {
        const MAX_INIT_COMPRESSED_SIZE_CAP: usize = 4 * 1024;
        const ESTIMATED_COMPRESSION_RATIO: usize = 3; // estimated 2.5-3x compression

        let timer = Instant::now();
        let input_len = input.len();
        self.total_in += input_len;

        // Only small inputs get the amplified reservation: always reserving 3x
        // the size of the input buffer would be a DoS risk.
        let growth_factor = if input_len < MAX_INIT_COMPRESSED_SIZE_CAP {
            ESTIMATED_COMPRESSION_RATIO
        } else {
            1
        };
        self.decompress.get_mut().reserve(input_len * growth_factor);

        // Writing into the Vec itself never fails, so the only possible error is
        // that the input data was not actually gzip compressed.
        self.decompress
            .write_all(input)
            .or_err(COMPRESSION_ERROR, "while decompress Gzip")?;
        if end {
            self.decompress
                .try_finish()
                .or_err(COMPRESSION_ERROR, "while decompress Gzip")?;
        }

        let produced = self.decompress.get_ref().len();
        self.total_out += produced;
        self.duration += timer.elapsed();

        // Hand the buffer to the caller and leave an empty one behind.
        let output = std::mem::take(self.decompress.get_mut());
        Ok(output.into()) // into() Bytes will drop excess capacity
    }

    /// Report `(name, bytes in, bytes out, time spent)` for this decompressor.
    fn stat(&self) -> (&'static str, usize, usize, Duration) {
        let Self {
            total_in,
            total_out,
            duration,
            ..
        } = self;
        ("de-gzip", *total_in, *total_out, *duration)
    }
}
pub struct Compressor {
// TODO: enum for other compression algorithms
@ -66,6 +116,20 @@ impl Encode for Compressor {
}
use std::ops::{Deref, DerefMut};
/// Expose the wrapped [`GzDecoder`] for read-only access.
impl Deref for Decompressor {
    type Target = GzDecoder<Vec<u8>>;

    fn deref(&self) -> &Self::Target {
        let Self { decompress, .. } = self;
        decompress
    }
}
/// Expose the wrapped `GzDecoder` for mutable access.
impl DerefMut for Decompressor {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { decompress, .. } = self;
        decompress
    }
}
impl Deref for Compressor {
type Target = GzEncoder<Vec<u8>>;
@ -100,4 +164,21 @@ mod tests_stream {
assert!(compressor.get_ref().is_empty());
}
// Decompress a fixed gzip stream in a single call and check both the payload
// and the bookkeeping counters.
#[test]
fn gunzip_data() {
    // gzip encoding of "abcdefg"
    let gzipped: &[u8] = &[
        0x1f, 0x8b, 0x08, 0, 0, 0, 0, 0, 0, 255, 75, 76, 74, 78, 73, 77, 75, 7, 0, 166, 106,
        42, 49, 7, 0, 0, 0,
    ];
    let mut gunzip = Decompressor::new();

    let plain = gunzip.encode(gzipped, true).unwrap();

    assert_eq!(&plain[..], b"abcdefg");
    assert_eq!(gunzip.total_in, gzipped.len());
    assert_eq!(gunzip.total_out, plain.len());
    // the internal buffer must have been handed over completely
    assert!(gunzip.get_ref().is_empty());
}
}

View file

@ -66,10 +66,10 @@ pub struct ResponseCompressionCtx(CtxInner);
enum CtxInner {
HeaderPhase {
decompress_enable: bool,
// Store the preferred list to compare with content-encoding
accept_encoding: Vec<Algorithm>,
encoding_levels: [u32; Algorithm::COUNT],
decompress_enable: [bool; Algorithm::COUNT],
},
BodyPhase(Option<Box<dyn Encode + Send + Sync>>),
}
@ -80,9 +80,9 @@ impl ResponseCompressionCtx {
/// The `decompress_enable` flag will tell the ctx to decompress if needed.
pub fn new(compression_level: u32, decompress_enable: bool) -> Self {
Self(CtxInner::HeaderPhase {
decompress_enable,
accept_encoding: Vec::new(),
encoding_levels: [compression_level; Algorithm::COUNT],
decompress_enable: [decompress_enable; Algorithm::COUNT],
})
}
@ -92,9 +92,9 @@ impl ResponseCompressionCtx {
match &self.0 {
CtxInner::HeaderPhase {
decompress_enable,
accept_encoding: _,
encoding_levels: levels,
} => levels.iter().any(|l| *l != 0) || *decompress_enable,
..
} => levels.iter().any(|l| *l != 0) || decompress_enable.iter().any(|d| *d),
CtxInner::BodyPhase(c) => c.is_some(),
}
}
@ -103,11 +103,7 @@ impl ResponseCompressionCtx {
/// algorithm name, in bytes, out bytes, time took for the compression
pub fn get_info(&self) -> Option<(&'static str, usize, usize, Duration)> {
match &self.0 {
CtxInner::HeaderPhase {
decompress_enable: _,
accept_encoding: _,
encoding_levels: _,
} => None,
CtxInner::HeaderPhase { .. } => None,
CtxInner::BodyPhase(c) => c.as_ref().map(|c| c.stat()),
}
}
@ -118,9 +114,8 @@ impl ResponseCompressionCtx {
pub fn adjust_level(&mut self, new_level: u32) {
match &mut self.0 {
CtxInner::HeaderPhase {
decompress_enable: _,
accept_encoding: _,
encoding_levels: levels,
..
} => {
*levels = [new_level; Algorithm::COUNT];
}
@ -134,9 +129,8 @@ impl ResponseCompressionCtx {
pub fn adjust_algorithm_level(&mut self, algorithm: Algorithm, new_level: u32) {
match &mut self.0 {
CtxInner::HeaderPhase {
decompress_enable: _,
accept_encoding: _,
encoding_levels: levels,
..
} => {
levels[algorithm.index()] = new_level;
}
@ -144,17 +138,29 @@ impl ResponseCompressionCtx {
}
}
/// Adjust the decompression flag.
/// Adjust the decompression flag for all compression algorithms.
/// # Panic
/// This function will panic if it has already started encoding the response body.
pub fn adjust_decompression(&mut self, enabled: bool) {
match &mut self.0 {
CtxInner::HeaderPhase {
decompress_enable,
accept_encoding: _,
encoding_levels: _,
decompress_enable, ..
} => {
*decompress_enable = enabled;
*decompress_enable = [enabled; Algorithm::COUNT];
}
CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
}
}
/// Adjust the decompression flag for a specific algorithm.
/// # Panic
/// This function will panic if it has already started encoding the response body.
pub fn adjust_algorithm_decompression(&mut self, algorithm: Algorithm, enabled: bool) {
match &mut self.0 {
CtxInner::HeaderPhase {
decompress_enable, ..
} => {
decompress_enable[algorithm.index()] = enabled;
}
CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
}
@ -206,7 +212,7 @@ impl ResponseCompressionCtx {
if depends_on_accept_encoding(
resp,
levels.iter().any(|level| *level != 0),
*decompress_enable,
decompress_enable,
) {
// The response depends on the Accept-Encoding header, make sure to indicate it
// in the Vary response header.
@ -218,7 +224,9 @@ impl ResponseCompressionCtx {
let encoder = match action {
Action::Noop => None,
Action::Compress(algorithm) => algorithm.compressor(levels[algorithm.index()]),
Action::Decompress(algorithm) => algorithm.decompressor(*decompress_enable),
Action::Decompress(algorithm) => {
algorithm.decompressor(decompress_enable[algorithm.index()])
}
};
if encoder.is_some() {
adjust_response_header(resp, &action);
@ -327,6 +335,7 @@ impl Algorithm {
None
} else {
match self {
Self::Gzip => Some(Box::new(gzip::Decompressor::new())),
Self::Brotli => Some(Box::new(brotli::Decompressor::new())),
_ => None, // not implemented
}
@ -433,11 +442,12 @@ fn test_accept_encoding_req_header() {
fn depends_on_accept_encoding(
resp: &ResponseHeader,
compress_enabled: bool,
decompress_enabled: bool,
decompress_enabled: &[bool],
) -> bool {
use http::header::CONTENT_ENCODING;
(decompress_enabled && resp.headers.get(CONTENT_ENCODING).is_some())
(decompress_enabled.iter().any(|enabled| *enabled)
&& resp.headers.get(CONTENT_ENCODING).is_some())
|| (compress_enabled && compressible(resp))
}

View file

@ -25,6 +25,7 @@ use http::{header::AsHeaderName, HeaderMap};
use log::error;
use pingora_error::Result;
use pingora_http::{RequestHeader, ResponseHeader};
use std::time::Duration;
/// HTTP server session object for both HTTP/1.x and HTTP/2
pub enum Session {
@ -52,7 +53,7 @@ impl Session {
/// else with the session.
/// - `Ok(true)`: successful
/// - `Ok(false)`: client exit without sending any bytes. This is normal on reused connection.
/// In this case the user should give up this session.
/// In this case the user should give up this session.
pub async fn read_request(&mut self) -> Result<bool> {
match self {
Self::H1(s) => {
@ -188,6 +189,48 @@ impl Session {
}
}
/// Sets the downstream write timeout. It triggers when we are unable to
/// write to the stream within `timeout`. If a `min_send_rate` is configured,
/// the timeout calculated from `min_send_rate` has higher priority.
///
/// This is a noop for h2.
pub fn set_write_timeout(&mut self, timeout: Duration) {
    // only the H1 session is configured; H2 ignores the call
    if let Self::H1(h1) = self {
        h1.set_write_timeout(timeout);
    }
}
/// Sets the minimum downstream send rate in bytes per second. The rate is
/// used to calculate a write timeout in seconds from the size of the buffer
/// being written. A configured `min_send_rate` has higher priority than a
/// set `write_timeout`. The minimum send rate must be greater than zero.
///
/// The calculated write timeout is guaranteed to be at least 1s if
/// `min_send_rate` is greater than zero; a send rate of zero is a noop.
///
/// This is a noop for h2.
pub fn set_min_send_rate(&mut self, rate: usize) {
    // only the H1 session is configured; H2 ignores the call
    if let Self::H1(h1) = self {
        h1.set_min_send_rate(rate);
    }
}
/// Sets whether informational responses are skipped instead of being written
/// downstream.
///
/// For HTTP/1.1 this is a noop if the response is Upgrade or Continue and
/// Expect: 100-continue was set on the request.
///
/// This is a noop for h2 because informational responses are always ignored.
pub fn set_ignore_info_resp(&mut self, ignore: bool) {
    // H2 needs no flag: informational responses are always ignored there
    if let Self::H1(h1) = self {
        h1.set_ignore_info_resp(ignore);
    }
}
/// Return a digest of the request including the method, path and Host header
// TODO: make this use a `Formatter`
pub fn request_summary(&self) -> String {
@ -357,7 +400,7 @@ impl Session {
}
}
/// Return the digest for the session.
/// Return the [Digest] for the connection.
pub fn digest(&self) -> Option<&Digest> {
match self {
Self::H1(s) => Some(s.digest()),
@ -365,6 +408,16 @@ impl Session {
}
}
/// Return a mutable reference to the [Digest] of the connection.
///
/// Will return `None` if multiple H2 streams are open.
pub fn digest_mut(&mut self) -> Option<&mut Digest> {
    match self {
        // H2 decides on its own whether mutable access can be handed out
        Self::H2(h2) => h2.digest_mut(),
        // H1 always has a digest to hand out
        Self::H1(h1) => {
            let digest = h1.digest_mut();
            Some(digest)
        }
    }
}
/// Return the client (peer) address of the connection.
pub fn client_addr(&self) -> Option<&SocketAddr> {
match self {

View file

@ -631,10 +631,19 @@ impl HttpSession {
// TODO: support h1 trailer
}
/// Return the [Digest] of the connection.
///
/// For a reused connection, the timing in the digest will reflect its initial handshakes.
/// The caller should check if the connection is reused to avoid misuse of the timing field.
pub fn digest(&self) -> &Digest {
&self.digest
}
/// Return a mutable [Digest] reference for the connection, see [`Self::digest`] for more details.
pub fn digest_mut(&mut self) -> &mut Digest {
&mut self.digest
}
/// Return the server (peer) address recorded in the connection digest.
pub fn server_addr(&self) -> Option<&SocketAddr> {
self.digest()

View file

@ -153,6 +153,14 @@ pub(super) fn is_upgrade_req(req: &RequestHeader) -> bool {
req.version == http::Version::HTTP_11 && req.headers.get(header::UPGRADE).is_some()
}
/// Whether `req` is an HTTP/1.1 request carrying `Expect: 100-continue`
/// (the header value is compared case-insensitively).
pub(super) fn is_expect_continue_req(req: &RequestHeader) -> bool {
    if req.version != http::Version::HTTP_11 {
        return false;
    }
    // https://www.rfc-editor.org/rfc/rfc9110#section-10.1.1
    match req.headers.get(header::EXPECT) {
        Some(value) => value.as_bytes().eq_ignore_ascii_case(b"100-continue"),
        None => false,
    }
}
// Unlike the upgrade check on request, this function doesn't check the Upgrade or Connection header
// because when seeing 101, we assume the server accepts to switch protocol.
// In reality it is not common that some servers don't send all the required headers to establish

View file

@ -72,6 +72,10 @@ pub struct HttpSession {
upgraded: bool,
/// Digest to track underlying connection metrics
digest: Box<Digest>,
/// Minimum send rate to the client
min_send_rate: Option<usize>,
/// When this is enabled informational response headers will not be proxied downstream
ignore_info_resp: bool,
}
impl HttpSession {
@ -106,6 +110,8 @@ impl HttpSession {
retry_buffer: None,
upgraded: false,
digest,
min_send_rate: None,
ignore_info_resp: false,
}
}
@ -385,6 +391,11 @@ impl HttpSession {
/// Write the response header to the client.
/// This function can be called more than once to send 1xx informational headers excluding 101.
pub async fn write_response_header(&mut self, mut header: Box<ResponseHeader>) -> Result<()> {
if header.status.is_informational() && self.ignore_info_resp(header.status.into()) {
debug!("ignoring informational headers");
return Ok(());
}
if let Some(resp) = self.response_written.as_ref() {
if !resp.status.is_informational() || self.upgraded {
warn!("Respond header is already sent, cannot send again");
@ -406,7 +417,7 @@ impl HttpSession {
header.insert_header(header::CONNECTION, connection_value)?;
}
if header.status.as_u16() == 101 {
if header.status == 101 {
// make sure the connection is closed at the end when 101/upgrade is used
self.set_keepalive(None);
}
@ -507,10 +518,34 @@ impl HttpSession {
(None, None)
}
/// Decide whether an informational response with the given `status` should be
/// dropped instead of being written downstream.
fn ignore_info_resp(&self, status: u16) -> bool {
    // nothing is ignored unless the flag is set
    if !self.ignore_info_resp {
        return false;
    }
    match status {
        // an Upgrade response is never ignored
        101 => false,
        // 100 Continue is kept when the request set Expect: 100-continue
        100 => !self.is_expect_continue_req(),
        _ => true,
    }
}
/// Whether the request header read so far asks for `100-continue`; `false`
/// when no request header is available.
fn is_expect_continue_req(&self) -> bool {
    self.request_header
        .as_deref()
        .map_or(false, is_expect_continue_req)
}
/// Evaluate the `Connection` header of this session with `is_buf_keepalive`.
fn is_connection_keepalive(&self) -> Option<bool> {
    let connection_header = self.get_header(header::CONNECTION);
    is_buf_keepalive(connection_header)
}
// Timeout to apply when writing `buf_len` bytes: derived from `min_send_rate`
// when a non-zero rate is set, otherwise the configured `write_timeout`.
fn write_timeout(&self, buf_len: usize) -> Option<Duration> {
    match self.min_send_rate {
        Some(rate) if rate > 0 => {
            // treating the buffer as at least `rate` bytes keeps the timeout >= 1s
            let effective_len = buf_len.max(rate);
            let millis = (effective_len as f64 / rate as f64) * 1000.0;
            // truncates unrealistically large values (we'll be out of memory before this happens)
            Some(Duration::from_millis(millis as u64))
        }
        _ => self.write_timeout,
    }
}
/// Apply keepalive settings according to the client
/// For HTTP 1.1, assume keepalive as long as there is no `Connection: Close` request header.
/// For HTTP 1.0, only keepalive if there is an explicit header `Connection: keep-alive`.
@ -579,7 +614,7 @@ impl HttpSession {
/// to be written, e.g., writing more bytes than what the `Content-Length` header suggests
pub async fn write_body(&mut self, buf: &[u8]) -> Result<Option<usize>> {
// TODO: check if the response header is written
match self.write_timeout {
match self.write_timeout(buf.len()) {
Some(t) => match timeout(t, self.do_write_body(buf)).await {
Ok(res) => res,
Err(_) => Error::e_explain(WriteTimedout, format!("writing body, timeout: {t:?}")),
@ -588,7 +623,7 @@ impl HttpSession {
}
}
async fn write_body_buf(&mut self) -> Result<Option<usize>> {
async fn do_write_body_buf(&mut self) -> Result<Option<usize>> {
// Don't flush empty chunks, they are considered end of body for chunks
if self.body_write_buf.is_empty() {
return Ok(None);
@ -609,6 +644,16 @@ impl HttpSession {
written
}
// Flush the buffered body via `do_write_body_buf`, bounded by the write
// timeout computed for the buffered size (when there is one).
async fn write_body_buf(&mut self) -> Result<Option<usize>> {
    let pending_len = self.body_write_buf.len();
    let Some(t) = self.write_timeout(pending_len) else {
        // no timeout configured: write without a deadline
        return self.do_write_body_buf().await;
    };
    match timeout(t, self.do_write_body_buf()).await {
        Ok(res) => res,
        Err(_) => Error::e_explain(WriteTimedout, format!("writing body, timeout: {t:?}")),
    }
}
fn maybe_force_close_body_reader(&mut self) {
if self.upgraded && !self.body_reader.body_done() {
// response is done, reset the request body to close
@ -778,11 +823,45 @@ impl HttpSession {
}
}
/// Sets the downstream write timeout. This will trigger if we're unable
/// to write to the stream after `timeout`. If a `min_send_rate` is
/// configured then the `min_send_rate` calculated timeout has higher priority.
pub fn set_write_timeout(&mut self, timeout: Duration) {
self.write_timeout = Some(timeout);
}
/// Sets the minimum downstream send rate in bytes per second. The rate is
/// used to calculate a write timeout in seconds from the size of the buffer
/// being written. A configured `min_send_rate` has higher priority than a
/// set `write_timeout`. The minimum send rate must be greater than zero.
///
/// The calculated write timeout is guaranteed to be at least 1s if
/// `min_send_rate` is greater than zero; a send rate of zero is a noop.
pub fn set_min_send_rate(&mut self, min_send_rate: usize) {
    // zero leaves any previously configured rate untouched
    if min_send_rate == 0 {
        return;
    }
    self.min_send_rate = Some(min_send_rate);
}
/// Sets whether we ignore writing informational responses downstream.
///
/// Even when enabled, an Upgrade (101) response is still written, and so is
/// 100 Continue when Expect: 100-continue was set on the request.
pub fn set_ignore_info_resp(&mut self, ignore: bool) {
self.ignore_info_resp = ignore;
}
/// Return a shared reference to the [Digest] stored for this session's connection.
pub fn digest(&self) -> &Digest {
&self.digest
}
/// Return a mutable reference to the [Digest] stored for this session's connection.
pub fn digest_mut(&mut self) -> &mut Digest {
&mut self.digest
}
/// Return the client (peer) address of the underlying connection.
pub fn client_addr(&self) -> Option<&SocketAddr> {
self.digest()
@ -1421,6 +1500,75 @@ mod tests_stream {
.unwrap();
}
// With `ignore_info_resp` set and no request read, a 100 Continue response
// must be dropped: the mock IO only expects the final 200 response on the wire.
#[tokio::test]
async fn write_informational_ignored() {
let wire = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
let mock_io = Builder::new().write(wire).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
// ignore the 100 Continue
http_stream.ignore_info_resp = true;
let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap();
http_stream
.write_response_header_ref(&response_100)
.await
.unwrap();
let mut response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap();
response_200.append_header("Foo", "Bar").unwrap();
// presumably keeps the session from adding its own headers so the bytes match `wire` exactly
http_stream.update_resp_headers = false;
http_stream
.write_response_header_ref(&response_200)
.await
.unwrap();
}
// Even with `ignore_info_resp` set, 100 Continue must reach the client when the
// request carried `Expect: 100-continue`: the mock expects both the 100 and the 200.
#[tokio::test]
async fn write_informational_100_not_ignored_if_expect_continue() {
let input = b"GET / HTTP/1.1\r\nExpect: 100-continue\r\n\r\n";
let output = b"HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).write(output).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
// the request has to be read first so that the Expect header is known
http_stream.read_request().await.unwrap();
http_stream.ignore_info_resp = true;
// 100 Continue is not ignored due to Expect: 100-continue on request
let response_100 = ResponseHeader::build(StatusCode::CONTINUE, None).unwrap();
http_stream
.write_response_header_ref(&response_100)
.await
.unwrap();
let mut response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap();
response_200.append_header("Foo", "Bar").unwrap();
// presumably keeps the session from adding its own headers so the bytes match `output` exactly
http_stream.update_resp_headers = false;
http_stream
.write_response_header_ref(&response_200)
.await
.unwrap();
}
// `Expect: 100-continue` only exempts status 100: a 102 Processing response is
// still dropped when `ignore_info_resp` is set, so the mock expects only the 200.
#[tokio::test]
async fn write_informational_1xx_ignored_if_expect_continue() {
let input = b"GET / HTTP/1.1\r\nExpect: 100-continue\r\n\r\n";
let output = b"HTTP/1.1 200 OK\r\nFoo: Bar\r\n\r\n";
let mock_io = Builder::new().read(&input[..]).write(output).build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
// the request has to be read first so that the Expect header is known
http_stream.read_request().await.unwrap();
http_stream.ignore_info_resp = true;
// 102 Processing is ignored
let response_102 = ResponseHeader::build(StatusCode::PROCESSING, None).unwrap();
http_stream
.write_response_header_ref(&response_102)
.await
.unwrap();
let mut response_200 = ResponseHeader::build(StatusCode::OK, None).unwrap();
response_200.append_header("Foo", "Bar").unwrap();
// presumably keeps the session from adding its own headers so the bytes match `output` exactly
http_stream.update_resp_headers = false;
http_stream
.write_response_header_ref(&response_200)
.await
.unwrap();
}
#[tokio::test]
async fn write_101_switching_protocol() {
let wire = b"HTTP/1.1 101 Switching Protocols\r\nFoo: Bar\r\n\r\n";
@ -1583,6 +1731,30 @@ mod tests_stream {
assert!(written.is_none());
}
// The mock IO waits 500ms before accepting the body while the write timeout is
// 100ms, so `write_body_buf` is expected to fail with `WriteTimedout`.
// NOTE(review): the expected panic message presumably comes from the mock IO
// complaining about the unwritten body when it is dropped — confirm.
#[tokio::test]
#[should_panic(expected = "There is still data left to write.")]
async fn test_write_body_buf_write_timeout() {
let wire1 = b"HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\n";
let wire2 = b"abc";
let mock_io = Builder::new()
.write(wire1)
.wait(Duration::from_millis(500))
.write(wire2)
.build();
let mut http_stream = HttpSession::new(Box::new(mock_io));
http_stream.write_timeout = Some(Duration::from_millis(100));
let mut new_response = ResponseHeader::build(StatusCode::OK, None).unwrap();
new_response.append_header("Content-Length", "3").unwrap();
http_stream.update_resp_headers = false;
http_stream
.write_response_header_ref(&new_response)
.await
.unwrap();
// put the body straight into the write buffer so that only the flush is timed
http_stream.body_write_buf = BytesMut::from(&b"abc"[..]);
let res = http_stream.write_body_buf().await;
assert_eq!(res.unwrap_err().etype(), &WriteTimedout);
}
#[tokio::test]
async fn test_write_continue_resp() {
let wire = b"HTTP/1.1 100 Continue\r\n\r\n";
@ -1610,6 +1782,48 @@ mod tests_stream {
response.set_version(http::Version::HTTP_11);
assert!(!is_upgrade_resp(&response));
}
// Without a min send rate, the configured write timeout is returned as is.
#[test]
fn test_get_write_timeout() {
    let mock_io = Builder::new().build();
    let mut session = HttpSession::new(Box::new(mock_io));
    let configured = Duration::from_secs(5);
    session.set_write_timeout(configured);
    assert_eq!(Some(configured), session.write_timeout(50));
}
// A fresh session has neither a write timeout nor a min send rate.
#[test]
fn test_get_write_timeout_none() {
    let mock_io = Builder::new().build();
    let session = HttpSession::new(Box::new(mock_io));
    assert!(session.write_timeout(50).is_none());
}
// A min send rate of zero is a noop, so no timeout gets calculated.
#[test]
fn test_get_write_timeout_min_send_rate_zero_noop() {
    let mock_io = Builder::new().build();
    let mut session = HttpSession::new(Box::new(mock_io));
    session.set_min_send_rate(0);
    assert!(session.write_timeout(50).is_none());
}
// The rate-derived timeout wins over an explicit write timeout:
// 149000 bytes at 5000 bytes/s is 29.8s, not the configured 60s.
#[test]
fn test_get_write_timeout_min_send_rate_overrides_write_timeout() {
    let mock_io = Builder::new().build();
    let mut session = HttpSession::new(Box::new(mock_io));
    session.set_write_timeout(Duration::from_secs(60));
    session.set_min_send_rate(5000);
    let from_rate = Duration::from_millis(29800);
    assert_eq!(Some(from_rate), session.write_timeout(149000));
}
// Even for an empty buffer, the rate-derived timeout is at least one second.
#[test]
fn test_get_write_timeout_min_send_rate_max_zero_buf() {
    let mock_io = Builder::new().build();
    let mut session = HttpSession::new(Box::new(mock_io));
    session.set_min_send_rate(1);
    let one_second = Duration::from_secs(1);
    assert_eq!(Some(one_second), session.write_timeout(0));
}
}
#[cfg(test)]

View file

@ -305,11 +305,18 @@ impl Http2Session {
/// Return the [Digest] of the connection
///
/// For reused connection, the timing in the digest will reflect its initial handshakes
/// The caller should check if the connection is reused to avoid misuse the timing field
/// The caller should check if the connection is reused to avoid misuse of the timing field.
pub fn digest(&self) -> Option<&Digest> {
Some(self.conn.digest())
}
/// Return a mutable [Digest] reference for the connection, see [`Self::digest`] for more details.
///
/// Will return `None` if multiple H2 streams are open.
pub fn digest_mut(&mut self) -> Option<&mut Digest> {
// the underlying connection decides whether mutable access can be handed out
self.conn.digest_mut()
}
/// Return the server (peer) address recorded in the connection digest.
pub fn server_addr(&self) -> Option<&SocketAddr> {
self.conn
@ -342,6 +349,19 @@ impl Http2Session {
if self.ping_timedout() {
e.etype = PING_TIMEDOUT;
}
// is_go_away: this connection is being torn down, so the request
// should be retried via another connection
if self.response_header.is_none() {
if let Some(err) = e.root_cause().downcast_ref::<h2::Error>() {
if err.is_go_away()
&& err.is_remote()
&& err.reason().map_or(false, |r| r == h2::Reason::NO_ERROR)
{
e.retry = true.into();
}
}
}
e
}
}
@ -360,7 +380,7 @@ pub fn write_body(send_body: &mut SendStream<Bytes>, data: Bytes, end: bool) ->
/* Types of errors during h2 header read
1. peer requests to downgrade to h1, mostly IIS server for NTLM: we will downgrade and retry
2. peer sends invalid h2 frames, usually sending h1 only header: we will downgrade and retry
3. peer sends GO_AWAY(NO_ERROR) on reused conn, usually hit http2_max_requests: we will retry
3. peer sends GO_AWAY(NO_ERROR), connection is being shut down: we will retry
4. peer IO error on reused conn, usually firewall kills old conn: we will retry
5. All other errors will terminate the request
*/
@ -386,9 +406,8 @@ fn handle_read_header_error(e: h2::Error) -> Box<Error> {
&& e.reason().map_or(false, |r| r == h2::Reason::NO_ERROR)
{
// is_go_away: retry via another connection, this connection is being torn down
// only retry if the connection is reused
let mut err = Error::because(H2Error, "while reading h2 header", e);
err.retry = RetryType::ReusedOnly;
err.retry = true.into();
err
} else if e.is_io() {
// is_io: typical if a previously reused connection silently drops it

View file

@ -194,22 +194,21 @@ impl HttpSession {
return Ok(());
}
// FIXME: we should ignore 1xx header because send_response() can only be called once
// https://github.com/hyperium/h2/issues/167
if let Some(resp) = self.response_written.as_ref() {
if !resp.status.is_informational() {
warn!("Respond header is already sent, cannot send again");
return Ok(());
}
if header.status.is_informational() {
// ignore informational response 1xx header because send_response() can only be called once
// https://github.com/hyperium/h2/issues/167
debug!("ignoring informational headers");
return Ok(());
}
// no need to add these headers to 1xx responses
if !header.status.is_informational() {
/* update headers */
header.insert_header(header::DATE, get_cached_date())?;
if self.response_written.as_ref().is_some() {
warn!("Response header is already sent, cannot send again");
return Ok(());
}
/* update headers */
header.insert_header(header::DATE, get_cached_date())?;
// remove other h1 hop headers that cannot be present in H2
// https://httpwg.org/specs/rfc7540.html#n-connection-specific-header-fields
header.remove_header(&header::TRANSFER_ENCODING);
@ -454,6 +453,11 @@ impl HttpSession {
Some(&self.digest)
}
/// Return a mutable [Digest] reference for the connection.
///
/// Returns `None` when the digest `Arc` is currently shared, because
/// `Arc::get_mut` only hands out a reference to an unshared value.
pub fn digest_mut(&mut self) -> Option<&mut Digest> {
Arc::get_mut(&mut self.digest)
}
/// Return the server (local) address recorded in the connection digest.
pub fn server_addr(&self) -> Option<&SocketAddr> {
self.digest.socket_digest.as_ref().map(|d| d.local_addr())?
@ -481,7 +485,8 @@ mod test {
expected_trailers.insert("test", HeaderValue::from_static("trailers"));
let trailers = expected_trailers.clone();
tokio::spawn(async move {
let mut handles = vec![];
handles.push(tokio::spawn(async move {
let (h2, connection) = h2::client::handshake(client).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
@ -505,7 +510,7 @@ mod test {
assert_eq!(data, server_body);
let resp_trailers = body.trailers().await.unwrap().unwrap();
assert_eq!(resp_trailers, expected_trailers);
});
}));
let mut connection = handshake(Box::new(server), None).await.unwrap();
let digest = Arc::new(Digest::default());
@ -515,7 +520,7 @@ mod test {
.unwrap()
{
let trailers = trailers.clone();
tokio::spawn(async move {
handles.push(tokio::spawn(async move {
let req = http.req_header();
assert_eq!(req.method, Method::GET);
assert_eq!(req.uri, "https://www.example.com/");
@ -540,7 +545,11 @@ mod test {
}
let response_header = Box::new(ResponseHeader::build(200, None).unwrap());
http.write_response_header(response_header, false).unwrap();
assert!(http
.write_response_header(response_header.clone(), false)
.is_ok());
// this write should be ignored otherwise we will error
assert!(http.write_response_header(response_header, false).is_ok());
// test idling after response header is sent
tokio::select! {
@ -554,7 +563,11 @@ mod test {
http.write_trailers(trailers).unwrap();
http.finish().unwrap();
});
}));
}
for handle in handles {
// ensure no panics
assert!(handle.await.is_ok());
}
}
}

View file

@ -275,6 +275,31 @@ pub fn set_tcp_fastopen_backlog(_fd: RawFd, _backlog: usize) -> Result<()> {
Ok(())
}
/// Set the DSCP `value` on the socket behind `fd`.
///
/// `IPV6_TCLASS` is used for IPv6 sockets and `IP_TOS` otherwise. Fails with a
/// `SocketError` if no inet address can be obtained for `fd` or if setting the
/// socket option fails.
#[cfg(target_os = "linux")]
pub fn set_dscp(fd: RawFd, value: u8) -> Result<()> {
    use super::socket::SocketAddr;
    use pingora_error::OkOrErr;

    // the address family of the socket decides which option has to be set
    let sock = SocketAddr::from_raw_fd(fd, false);
    let addr = sock
        .as_ref()
        .and_then(|s| s.as_inet())
        .or_err(SocketError, "failed to set dscp, invalid IP socket")?;

    let (level, option, context) = if addr.is_ipv6() {
        (
            libc::IPPROTO_IPV6,
            libc::IPV6_TCLASS,
            "failed to set dscp (IPV6_TCLASS)",
        )
    } else {
        (
            libc::IPPROTO_IP,
            libc::IP_TOS,
            "failed to set dscp (IP_TOS)",
        )
    };
    set_opt(fd, level, option, value as c_int).or_err(SocketError, context)
}
/// No-op on platforms other than Linux: both arguments are ignored and
/// `Ok(())` is returned.
#[cfg(not(target_os = "linux"))]
pub fn set_dscp(_fd: RawFd, _value: u8) -> Result<()> {
Ok(())
}
#[cfg(target_os = "linux")]
pub fn get_socket_cookie(fd: RawFd) -> io::Result<u64> {
get_opt_sized::<c_ulonglong>(fd, libc::SOL_SOCKET, libc::SO_COOKIE)
@ -344,13 +369,10 @@ fn wrap_os_connect_error(e: std::io::Error, context: String) -> Box<Error> {
Error::because(InternalError, context, e)
}
_ => match e.raw_os_error() {
Some(code) => match code {
libc::ENETUNREACH | libc::EHOSTUNREACH => {
Error::because(ConnectNoRoute, context, e)
}
_ => Error::because(ConnectError, context, e),
},
None => Error::because(ConnectError, context, e),
Some(libc::ENETUNREACH | libc::EHOSTUNREACH) => {
Error::because(ConnectNoRoute, context, e)
}
_ => Error::because(ConnectError, context, e),
},
}
}

View file

@ -15,6 +15,7 @@
//! Generic socket type
use crate::{Error, OrErr};
use log::warn;
use nix::sys::socket::{getpeername, getsockname, SockaddrStorage};
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
@ -174,14 +175,23 @@ impl std::str::FromStr for SocketAddr {
type Err = Box<Error>;
// This is very basic parsing logic, it might treat invalid IP:PORT str as UDS path
// TODO: require UDS to have some prefix
fn from_str(s: &str) -> Result<Self, Self::Err> {
match StdSockAddr::from_str(s) {
Ok(addr) => Ok(SocketAddr::Inet(addr)),
Err(_) => {
let uds_socket = StdUnixSockAddr::from_pathname(s)
.or_err(crate::BindError, "invalid UDS path")?;
Ok(SocketAddr::Unix(uds_socket))
if s.starts_with("unix:") {
// format unix:/tmp/server.socket
let path = s.trim_start_matches("unix:");
let uds_socket = StdUnixSockAddr::from_pathname(path)
.or_err(crate::BindError, "invalid UDS path")?;
Ok(SocketAddr::Unix(uds_socket))
} else {
match StdSockAddr::from_str(s) {
Ok(addr) => Ok(SocketAddr::Inet(addr)),
Err(_) => {
// Try to parse as UDS for backward compatibility
let uds_socket = StdUnixSockAddr::from_pathname(s)
.or_err(crate::BindError, "invalid UDS path")?;
warn!("Raw Unix domain socket path support will be deprecated, add 'unix:' prefix instead");
Ok(SocketAddr::Unix(uds_socket))
}
}
}
}
@ -246,4 +256,10 @@ mod test {
let uds: SocketAddr = "/tmp/my.sock".parse().unwrap();
assert!(uds.as_unix().is_some());
}
#[test]
fn parse_uds_with_prefix() {
    // A `unix:`-prefixed string must parse into the Unix-socket variant.
    let parsed = "unix:/tmp/my.sock".parse::<SocketAddr>().unwrap();
    assert!(parsed.as_unix().is_some());
}
}

View file

@ -29,6 +29,7 @@ pub use ssl::ALPN;
use async_trait::async_trait;
use std::fmt::Debug;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
/// Define how a protocol should shutdown its connection.
@ -231,10 +232,10 @@ impl ConnFdReusable for Path {
Ok(peer) => match UnixAddr::new(self) {
Ok(addr) => {
if addr == peer {
debug!("Unix FD to: {peer:?} is reusable");
debug!("Unix FD to: {peer} is reusable");
true
} else {
error!("Crit: unix FD mismatch: fd: {fd:?}, peer: {peer:?}, addr: {addr}",);
error!("Crit: unix FD mismatch: fd: {fd:?}, peer: {peer}, addr: {addr}",);
false
}
}
@ -256,12 +257,20 @@ impl ConnFdReusable for InetSocketAddr {
let fd = fd.as_raw_fd();
match getpeername::<SockaddrStorage>(fd) {
Ok(peer) => {
const ZERO: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0));
if self.ip() == ZERO {
// https://www.rfc-editor.org/rfc/rfc1122.html#section-3.2.1.3
// 0.0.0.0 should only be used as source IP not destination
// However in some systems this destination IP is mapped to 127.0.0.1.
// We just skip this check here to avoid false positive mismatch.
return true;
}
let addr = SockaddrStorage::from(*self);
if addr == peer {
debug!("Inet FD to: {peer:?} is reusable");
debug!("Inet FD to: {addr} is reusable");
true
} else {
error!("Crit: FD mismatch: fd: {fd:?}, addr: {addr:?}, peer: {peer:?}",);
error!("Crit: FD mismatch: fd: {fd:?}, addr: {addr}, peer: {peer}",);
false
}
}

View file

@ -68,9 +68,13 @@ where
H: Iterator<Item = (S, &'a Vec<u8>)>,
{
// TODO: validate that host doesn't have port
// TODO: support adding ad-hoc headers
let authority = format!("{host}:{port}");
let authority = if host.parse::<std::net::Ipv6Addr>().is_ok() {
format!("[{host}]:{port}")
} else {
format!("{host}:{port}")
};
let req = http::request::Builder::new()
.version(http::Version::HTTP_11)
.method(http::method::Method::CONNECT)
@ -217,6 +221,19 @@ mod test_sync {
assert_eq!(req.headers.get("Host").unwrap(), "pingora.org:123");
assert_eq!(req.headers.get("foo").unwrap(), "bar");
}
#[test]
fn test_generate_connect_header_ipv6() {
    // An IPv6 literal host must be wrapped in brackets in both the request
    // authority and the Host header.
    let extra_headers = BTreeMap::from([(String::from("foo"), b"bar".to_vec())]);
    let connect = generate_connect_header("::1", 123, extra_headers.iter()).unwrap();

    assert_eq!(connect.method, http::method::Method::CONNECT);
    assert_eq!(connect.uri.authority().unwrap(), "[::1]:123");
    assert_eq!(connect.headers.get("Host").unwrap(), "[::1]:123");
    assert_eq!(connect.headers.get("foo").unwrap(), "bar");
}
#[test]
fn test_request_to_wire_auth_form() {
let new_request = http::Request::builder()

View file

@ -119,7 +119,7 @@ impl Default for ServerConf {
/// Command-line options
///
/// Call `Opt::from_args()` to build this object from the process's command line arguments.
#[derive(Parser, Debug)]
#[derive(Parser, Debug, Default)]
#[clap(name = "basic", long_about = None)]
pub struct Opt {
/// Whether this server should try to upgrade from a running old server
@ -163,15 +163,6 @@ pub struct Opt {
pub conf: Option<String>,
}
/// Create the default instance of Opt based on the current command-line args.
/// This is equivalent to running `Opt::parse` but does not require the
/// caller to have included the `clap::Parser`
impl Default for Opt {
fn default() -> Self {
Opt::parse()
}
}
impl ServerConf {
// Does not have to be async until we want runtime reload
pub fn load_from_yaml<P>(path: P) -> Result<Self>
@ -236,6 +227,15 @@ impl ServerConf {
}
}
impl Opt {
    /// Create an instance of [`Opt`] by parsing the current command-line args.
    ///
    /// This is equivalent to running `Opt::parse` but does not require the
    /// caller to have imported the `clap::Parser` trait.
    pub fn parse_args() -> Self {
        Opt::parse()
    }
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -26,7 +26,7 @@ use super::Service;
use crate::server::{ListenFds, ShutdownWatch};
/// The background service interface
#[cfg_attr(not(doc_async_trait), async_trait)]
#[async_trait]
pub trait BackgroundService {
/// This function is called when the pingora server tries to start all the
/// services. The background service can return at anytime or wait for the

View file

@ -29,6 +29,7 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use crate::connectors::L4Connect;
use crate::protocols::l4::socket::SocketAddr;
use crate::protocols::ConnFdReusable;
use crate::protocols::TcpKeepalive;
@ -172,6 +173,12 @@ pub trait Peer: Display + Clone {
self.get_peer_options().and_then(|o| o.tcp_recv_buf)
}
/// The DSCP value that should be applied to the send side of this connection.
///
/// Taken from the `dscp` field of `get_peer_options()`; `None` when there are no
/// peer options or the field is unset.
/// See the [RFC](https://datatracker.ietf.org/doc/html/rfc2474) for more details.
fn dscp(&self) -> Option<u8> {
    let options = self.get_peer_options()?;
    options.dscp
}
/// Whether to enable TCP fast open.
fn tcp_fast_open(&self) -> bool {
self.get_peer_options()
@ -301,6 +308,7 @@ pub struct PeerOptions {
pub ca: Option<Arc<Box<[X509]>>>,
pub tcp_keepalive: Option<TcpKeepalive>,
pub tcp_recv_buf: Option<usize>,
pub dscp: Option<u8>,
pub no_header_eos: bool,
pub h2_ping_interval: Option<Duration>,
// how many concurrent h2 stream are allowed in the same connection
@ -315,6 +323,8 @@ pub struct PeerOptions {
pub tcp_fast_open: bool,
// use Arc because Clone is required but not allowed in trait object
pub tracer: Option<Tracer>,
// A custom L4 connector to use to establish new L4 connections
pub custom_l4: Option<Arc<dyn L4Connect + Send + Sync>>,
}
impl PeerOptions {
@ -334,6 +344,7 @@ impl PeerOptions {
ca: None,
tcp_keepalive: None,
tcp_recv_buf: None,
dscp: None,
no_header_eos: false,
h2_ping_interval: None,
max_h2_streams: 1,
@ -342,6 +353,7 @@ impl PeerOptions {
second_keyshare: true, // default true and noop when not using PQ curves
tcp_fast_open: false,
tracer: None,
custom_l4: None,
}
}
@ -403,6 +415,9 @@ pub struct HttpPeer {
pub sni: String,
pub proxy: Option<Proxy>,
pub client_cert_key: Option<Arc<CertKey>>,
/// a custom field to isolate connection reuse. Requests with different group keys
/// cannot share connections with each other.
pub group_key: u64,
pub options: PeerOptions,
}
@ -422,6 +437,7 @@ impl HttpPeer {
sni,
proxy: None,
client_cert_key: None,
group_key: 0,
options: PeerOptions::new(),
}
}
@ -462,6 +478,7 @@ impl HttpPeer {
headers,
}),
client_cert_key: None,
group_key: 0,
options: PeerOptions::new(),
}
}
@ -485,6 +502,7 @@ impl Hash for HttpPeer {
self.verify_cert().hash(state);
self.verify_hostname().hash(state);
self.alternative_cn().hash(state);
self.group_key.hash(state);
}
}

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-error"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"

View file

@ -509,6 +509,13 @@ pub trait OrErr<T, E> {
et: ErrorType,
context: F,
) -> Result<T, BError>;
/// Similar to or_err() but just to surface errors that are not [Error] (where `?` cannot be used directly).
///
/// or_err()/or_err_with() are still preferred because they make the error more readable and traceable.
fn or_fail(self) -> Result<T>
where
E: Into<Box<dyn ErrorTrait + Send + Sync>>;
}
impl<T, E> OrErr<T, E> for Result<T, E> {
@ -537,6 +544,13 @@ impl<T, E> OrErr<T, E> for Result<T, E> {
) -> Result<T, BError> {
self.map_err(|e| Error::explain(et, exp(e)))
}
fn or_fail(self) -> Result<T, BError>
where
E: Into<Box<dyn ErrorTrait + Send + Sync>>,
{
self.map_err(|e| Error::because(ErrorType::InternalError, "", e))
}
}
/// Helper trait to convert an [Option] to an [Error] with context.
@ -641,4 +655,19 @@ mod tests {
" InternalError context: none is an error!"
);
}
#[test]
fn test_into() {
    // A fallible function whose error type is not a pingora `Error`.
    fn other_error() -> Result<(), &'static str> {
        Err("oops")
    }

    // `or_fail()` converts the foreign error so that `?` can be used.
    fn surface_err() -> Result<()> {
        other_error().or_fail()?; // can return directly but want to showcase ?
        Ok(())
    }

    let rendered = surface_err().unwrap_err().to_string();
    assert_eq!(rendered, " InternalError context: cause: oops");
}
}

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-header-serde"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
@ -27,6 +27,6 @@ zstd-safe = { version = "7.1.0", features = ["std"] }
http = { workspace = true }
bytes = { workspace = true }
httparse = { workspace = true }
pingora-error = { version = "0.2.0", path = "../pingora-error" }
pingora-http = { version = "0.2.0", path = "../pingora-http" }
pingora-error = { version = "0.3.0", path = "../pingora-error" }
pingora-http = { version = "0.3.0", path = "../pingora-http" }
thread_local = "1.0"

View file

@ -131,7 +131,7 @@ fn resp_header_to_buf(resp: &ResponseHeader, buf: &mut Vec<u8>) -> usize {
}
// Should match pingora http1 setting
const MAX_HEADERS: usize = 160;
const MAX_HEADERS: usize = 256;
#[inline]
fn buf_to_http_header(buf: &[u8]) -> Result<ResponseHeader> {

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-http"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
@ -19,7 +19,7 @@ path = "src/lib.rs"
[dependencies]
http = { workspace = true }
bytes = { workspace = true }
pingora-error = { version = "0.2.0", path = "../pingora-error" }
pingora-error = { version = "0.3.0", path = "../pingora-error" }
[features]
default = []

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-ketama"
version = "0.2.0"
version = "0.3.0"
description = "Rust port of the nginx consistent hash function"
authors = ["Pingora Team <pingora@cloudflare.com>"]
license = "Apache-2.0"

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-limits"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
description = "A library for rate limiting and event frequency estimation"

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-load-balancing"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
@ -18,17 +18,19 @@ path = "src/lib.rs"
[dependencies]
async-trait = { workspace = true }
pingora-http = { version = "0.2.0", path = "../pingora-http" }
pingora-error = { version = "0.2.0", path = "../pingora-error" }
pingora-core = { version = "0.2.0", path = "../pingora-core", default-features = false }
pingora-ketama = { version = "0.2.0", path = "../pingora-ketama" }
pingora-runtime = { version = "0.2.0", path = "../pingora-runtime" }
pingora-http = { version = "0.3.0", path = "../pingora-http" }
pingora-error = { version = "0.3.0", path = "../pingora-error" }
pingora-core = { version = "0.3.0", path = "../pingora-core", default-features = false }
pingora-ketama = { version = "0.3.0", path = "../pingora-ketama" }
pingora-runtime = { version = "0.3.0", path = "../pingora-runtime" }
arc-swap = "1"
fnv = "1"
rand = "0"
tokio = { workspace = true }
futures = "0"
log = { workspace = true }
http = { workspace = true }
derivative = "2.2.0"
[dev-dependencies]

View file

@ -16,6 +16,7 @@
use arc_swap::ArcSwap;
use async_trait::async_trait;
use http::Extensions;
use pingora_core::protocols::l4::socket::SocketAddr;
use pingora_error::Result;
use std::io::Result as IoResult;
@ -62,6 +63,7 @@ impl Static {
let addrs = addrs.to_socket_addrs()?.map(|addr| Backend {
addr: SocketAddr::Inet(addr),
weight: 1,
ext: Extensions::new(),
});
upstreams.extend(addrs);
}

View file

@ -24,6 +24,16 @@ use pingora_http::{RequestHeader, ResponseHeader};
use std::sync::Arc;
use std::time::Duration;
/// [HealthObserve] is an interface for observing health changes of backends,
/// this is what's used for our health observation callback.
///
/// An implementation can be boxed and installed as the `health_changed_callback`
/// of a health check.
#[async_trait]
pub trait HealthObserve {
    /// Observes the health of a [Backend], can be used for monitoring purposes.
    ///
    /// `target` is the backend being reported and `healthy` is the health status
    /// passed along by the health check.
    async fn observe(&self, target: &Backend, healthy: bool);
}
/// Provided to a [HealthCheck] to observe changes to [Backend] health.
pub type HealthObserveCallback = Box<dyn HealthObserve + Send + Sync>;
/// [HealthCheck] is the interface to implement health check for backends
#[async_trait]
pub trait HealthCheck {
@ -31,6 +41,10 @@ pub trait HealthCheck {
///
/// `Ok(())` if the check passes, otherwise the check fails.
async fn check(&self, target: &Backend) -> Result<()>;
/// Called when the health changes for a [Backend].
async fn health_status_change(&self, _target: &Backend, _healthy: bool) {}
/// This function defines how many *consecutive* checks should flip the health of a backend.
///
/// For example: with `success`: `true`: this function should return the
@ -56,6 +70,8 @@ pub struct TcpHealthCheck {
/// set, it will also try to establish a TLS connection on top of the TCP connection.
pub peer_template: BasicPeer,
connector: TransportConnector,
/// A callback that is invoked when the `healthy` status changes for a [Backend].
pub health_changed_callback: Option<HealthObserveCallback>,
}
impl Default for TcpHealthCheck {
@ -67,6 +83,7 @@ impl Default for TcpHealthCheck {
consecutive_failure: 1,
peer_template,
connector: TransportConnector::new(None),
health_changed_callback: None,
}
}
}
@ -110,6 +127,12 @@ impl HealthCheck for TcpHealthCheck {
peer._address = target.addr.clone();
self.connector.get_stream(&peer).await.map(|_| {})
}
/// Passes the reported health status of `target` on to `health_changed_callback`;
/// does nothing when no callback is configured.
async fn health_status_change(&self, target: &Backend, healthy: bool) {
    let Some(observer) = self.health_changed_callback.as_ref() else {
        return;
    };
    observer.observe(target, healthy).await;
}
}
type Validator = Box<dyn Fn(&ResponseHeader) -> Result<()> + Send + Sync>;
@ -133,9 +156,9 @@ pub struct HttpHealthCheck {
/// Whether the underlying TCP/TLS connection can be reused across checks.
///
/// * `false` will make sure that every health check goes through TCP (and TLS) handshakes.
/// Established connections sometimes hide the issue of firewalls and L4 LB.
/// Established connections sometimes hide the issue of firewalls and L4 LB.
/// * `true` will try to reuse connections across checks, this is the more efficient and fast way
/// to perform health checks.
/// to perform health checks.
pub reuse_connection: bool,
/// The request header to send to the backend
pub req: RequestHeader,
@ -147,6 +170,8 @@ pub struct HttpHealthCheck {
/// Sometimes the health check endpoint lives on a different port than the actual backend.
/// Setting this option allows the health check to perform on the given port of the backend IP.
pub port_override: Option<u16>,
/// A callback that is invoked when the `healthy` status changes for a [Backend].
pub health_changed_callback: Option<HealthObserveCallback>,
}
impl HttpHealthCheck {
@ -174,6 +199,7 @@ impl HttpHealthCheck {
req,
validator: None,
port_override: None,
health_changed_callback: None,
}
}
@ -235,6 +261,11 @@ impl HealthCheck for HttpHealthCheck {
Ok(())
}
/// Passes the reported health status of `target` on to `health_changed_callback`;
/// does nothing when no callback is configured.
async fn health_status_change(&self, target: &Backend, healthy: bool) {
    let Some(observer) = self.health_changed_callback.as_ref() else {
        return;
    };
    observer.observe(target, healthy).await;
}
}
#[derive(Clone)]
@ -313,8 +344,15 @@ impl Health {
#[cfg(test)]
mod test {
use std::{
collections::{BTreeSet, HashMap},
sync::atomic::{AtomicU16, Ordering},
};
use super::*;
use crate::SocketAddr;
use crate::{discovery, Backends, SocketAddr};
use async_trait::async_trait;
use http::Extensions;
#[tokio::test]
async fn test_tcp_check() {
@ -323,6 +361,7 @@ mod test {
let backend = Backend {
addr: SocketAddr::Inet("1.1.1.1:80".parse().unwrap()),
weight: 1,
ext: Extensions::new(),
};
assert!(tcp_check.check(&backend).await.is_ok());
@ -330,6 +369,7 @@ mod test {
let backend = Backend {
addr: SocketAddr::Inet("1.1.1.1:79".parse().unwrap()),
weight: 1,
ext: Extensions::new(),
};
assert!(tcp_check.check(&backend).await.is_err());
@ -341,6 +381,7 @@ mod test {
let backend = Backend {
addr: SocketAddr::Inet("1.1.1.1:443".parse().unwrap()),
weight: 1,
ext: Extensions::new(),
};
assert!(tls_check.check(&backend).await.is_ok());
@ -353,6 +394,7 @@ mod test {
let backend = Backend {
addr: SocketAddr::Inet("1.1.1.1:443".parse().unwrap()),
weight: 1,
ext: Extensions::new(),
};
assert!(https_check.check(&backend).await.is_ok());
@ -375,10 +417,85 @@ mod test {
let backend = Backend {
addr: SocketAddr::Inet("1.1.1.1:80".parse().unwrap()),
weight: 1,
ext: Extensions::new(),
};
http_check.check(&backend).await.unwrap();
assert!(http_check.check(&backend).await.is_ok());
}
#[tokio::test]
async fn test_health_observe() {
struct Observe {
unhealthy_count: Arc<AtomicU16>,
}
#[async_trait]
impl HealthObserve for Observe {
async fn observe(&self, _target: &Backend, healthy: bool) {
if !healthy {
self.unhealthy_count.fetch_add(1, Ordering::Relaxed);
}
}
}
let good_backend = Backend::new("127.0.0.1:79").unwrap();
let new_good_backends = || -> (BTreeSet<Backend>, HashMap<u64, bool>) {
let mut healthy = HashMap::new();
healthy.insert(good_backend.hash_key(), true);
let mut backends = BTreeSet::new();
backends.extend(vec![good_backend.clone()]);
(backends, healthy)
};
// tcp health check
{
let unhealthy_count = Arc::new(AtomicU16::new(0));
let ob = Observe {
unhealthy_count: unhealthy_count.clone(),
};
let bob = Box::new(ob);
let tcp_check = TcpHealthCheck {
health_changed_callback: Some(bob),
..Default::default()
};
let discovery = discovery::Static::default();
let mut backends = Backends::new(Box::new(discovery));
backends.set_health_check(Box::new(tcp_check));
let result = new_good_backends();
backends.do_update(result.0, result.1, |_backend: Arc<BTreeSet<Backend>>| {});
// the backend is ready
assert!(backends.ready(&good_backend));
// run health check
backends.run_health_check(false).await;
assert!(1 == unhealthy_count.load(Ordering::Relaxed));
// backend is unhealthy
assert!(!backends.ready(&good_backend));
}
// http health check
{
let unhealthy_count = Arc::new(AtomicU16::new(0));
let ob = Observe {
unhealthy_count: unhealthy_count.clone(),
};
let bob = Box::new(ob);
let mut https_check = HttpHealthCheck::new("one.one.one.one", true);
https_check.health_changed_callback = Some(bob);
let discovery = discovery::Static::default();
let mut backends = Backends::new(Box::new(discovery));
backends.set_health_check(Box::new(https_check));
let result = new_good_backends();
backends.do_update(result.0, result.1, |_backend: Arc<BTreeSet<Backend>>| {});
// the backend is ready
assert!(backends.ready(&good_backend));
// run health check
backends.run_health_check(false).await;
assert!(1 == unhealthy_count.load(Ordering::Relaxed));
assert!(!backends.ready(&good_backend));
}
}
}

View file

@ -16,8 +16,14 @@
//! This crate provides common service discovery, health check and load balancing
//! algorithms for proxies to use.
// https://github.com/mcarton/rust-derivative/issues/112
// False positive for macro generated code
#![allow(clippy::non_canonical_partial_ord_impl)]
use arc_swap::ArcSwap;
use derivative::Derivative;
use futures::FutureExt;
pub use http::Extensions;
use pingora_core::protocols::l4::socket::SocketAddr;
use pingora_error::{ErrorType, OrErr, Result};
use std::collections::hash_map::DefaultHasher;
@ -45,25 +51,45 @@ pub mod prelude {
}
/// [Backend] represents a server to proxy or connect to.
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
#[derive(Derivative)]
#[derivative(Clone, Hash, PartialEq, PartialOrd, Eq, Ord, Debug)]
pub struct Backend {
/// The address to the backend server.
pub addr: SocketAddr,
/// The relative weight of the server. Load balancing algorithms will
/// proportionally distribute traffic according to this value.
pub weight: usize,
/// The extension field to put arbitrary data to annotate the Backend.
/// The data added here is opaque to this crate hence the data is ignored by
/// functionalities of this crate. For example, two backends with the same
/// [SocketAddr] and the same weight but different `ext` data are considered
/// identical.
/// See [Extensions] for how to add and read the data.
#[derivative(PartialEq = "ignore")]
#[derivative(PartialOrd = "ignore")]
#[derivative(Hash = "ignore")]
#[derivative(Ord = "ignore")]
pub ext: Extensions,
}
impl Backend {
/// Create a new [Backend] with `weight` 1. The function will try to parse
/// `addr` into a [std::net::SocketAddr].
pub fn new(addr: &str) -> Result<Self> {
Self::new_with_weight(addr, 1)
}
/// Creates a new [Backend] with the specified `weight`. The function will try to parse
/// `addr` into a [std::net::SocketAddr].
pub fn new_with_weight(addr: &str, weight: usize) -> Result<Self> {
let addr = addr
.parse()
.or_err(ErrorType::InternalError, "invalid socket addr")?;
Ok(Backend {
addr: SocketAddr::Inet(addr),
weight: 1,
weight,
ext: Extensions::new(),
})
// TODO: UDS
}
@ -130,8 +156,17 @@ impl Backends {
self.health_check = Some(hc.into())
}
/// Return true when the new is different from the current set of backends
fn do_update(&self, new_backends: BTreeSet<Backend>, enablement: HashMap<u64, bool>) -> bool {
/// Updates backends when the new is different from the current set,
/// the callback will be invoked when the new set of backends is different
/// from the current one so that the caller can update the selector accordingly.
fn do_update<F>(
&self,
new_backends: BTreeSet<Backend>,
enablement: HashMap<u64, bool>,
callback: F,
) where
F: Fn(Arc<BTreeSet<Backend>>),
{
if (**self.backends.load()) != new_backends {
let old_health = self.health.load();
let mut health = HashMap::with_capacity(new_backends.len());
@ -147,10 +182,14 @@ impl Backends {
health.insert(hash_key, backend_health);
}
// TODO: put backend and health under 1 ArcSwap so that this update is atomic
self.backends.store(Arc::new(new_backends));
// TODO: put this all under 1 ArcSwap so the update is atomic
// It's important the `callback()` executes first since computing selector backends might
// be expensive. For example, if a caller checks `backends` to see if any are available
// they may encounter false positives if the selector isn't ready yet.
let new_backends = Arc::new(new_backends);
callback(new_backends.clone());
self.backends.store(new_backends);
self.health.store(Arc::new(health));
true
} else {
// no backend change, just check enablement
for (hash_key, backend_enabled) in enablement.iter() {
@ -160,7 +199,6 @@ impl Backends {
backend_health.enable(*backend_enabled);
}
}
false
}
}
@ -199,12 +237,15 @@ impl Backends {
/// Call the service discovery method to update the collection of backends.
///
/// Return `true` when the new collection is different from the current set of backends.
/// This return value is useful to tell the caller when to rebuild things that are expensive to
/// update, such as consistent hashing rings.
pub async fn update(&self) -> Result<bool> {
/// The callback will be invoked when the new set of backends is different
/// from the current one so that the caller can update the selector accordingly.
pub async fn update<F>(&self, callback: F) -> Result<()>
where
F: Fn(Arc<BTreeSet<Backend>>),
{
let (new_backends, enablement) = self.discovery.discover().await?;
Ok(self.do_update(new_backends, enablement))
self.do_update(new_backends, enablement, callback);
Ok(())
}
/// Run health check on all backends if it is set.
@ -225,6 +266,7 @@ impl Backends {
let flipped =
h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
if flipped {
check.health_status_change(backend, errored.is_none()).await;
if let Some(e) = errored {
warn!("{backend:?} becomes unhealthy, {e}");
} else {
@ -320,11 +362,9 @@ where
/// This function will be called every `update_frequency` if this [LoadBalancer] instance
/// is running as a background service.
pub async fn update(&self) -> Result<()> {
if self.backends.update().await? {
self.selector
.store(Arc::new(S::build(&self.backends.get_backend())))
}
Ok(())
self.backends
.update(|backends| self.selector.store(Arc::new(S::build(&backends))))
.await
}
/// Return the first healthy [Backend] according to the selection algorithm and the
@ -378,6 +418,8 @@ where
#[cfg(test)]
mod test {
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
use super::*;
use async_trait::async_trait;
@ -408,10 +450,20 @@ mod test {
backends.set_health_check(check);
// true: new backend discovered
assert!(backends.update().await.unwrap());
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(updated.load(Relaxed));
// false: no new backend discovered
assert!(!backends.update().await.unwrap());
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(!updated.load(Relaxed));
backends.run_health_check(false).await;
@ -424,6 +476,31 @@ mod test {
assert!(backends.ready(&good2));
assert!(!backends.ready(&bad));
}
#[tokio::test]
async fn test_backends_with_ext() {
    let discovery = discovery::Static::default();

    // Two backends that each carry a differently typed extension value.
    let mut with_flag = Backend::new("1.1.1.1:80").unwrap();
    with_flag.ext.insert(true);
    let mut with_byte = Backend::new("1.0.0.1:80").unwrap();
    with_byte.ext.insert(1u8);

    discovery.add(with_flag.clone());
    discovery.add(with_byte.clone());

    let backends = Backends::new(Box::new(discovery));

    // fill in the backends
    backends.update(|_| {}).await.unwrap();

    let discovered = backends.get_backend();
    assert!(discovered.contains(&with_flag));
    assert!(discovered.contains(&with_byte));

    // The test expects 1.0.0.1 to be the first element of the set and 1.1.1.1 the
    // last, and the extension data to be carried along with each backend.
    let lowest = discovered.first().unwrap();
    assert_eq!(lowest.ext.get::<u8>(), Some(&1));

    let highest = discovered.last().unwrap();
    assert_eq!(highest.ext.get::<bool>(), Some(&true));
}
#[tokio::test]
async fn test_discovery_readiness() {
@ -449,7 +526,14 @@ mod test {
let discovery = TestDiscovery(discovery);
let backends = Backends::new(Box::new(discovery));
assert!(backends.update().await.unwrap());
// true: new backend discovered
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(updated.load(Relaxed));
let backend = backends.get_backend();
assert!(backend.contains(&good1));
@ -476,7 +560,12 @@ mod test {
backends.set_health_check(check);
// true: new backend discovered
assert!(backends.update().await.unwrap());
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(updated.load(Relaxed));
backends.run_health_check(true).await;
@ -484,4 +573,46 @@ mod test {
assert!(backends.ready(&good2));
assert!(!backends.ready(&bad));
}
/// Tests that exercise a backend update running concurrently with readers.
mod thread_safety {
    use super::*;

    /// Discovery stub that reports `expected` distinct backends on every call.
    struct MockDiscovery {
        expected: usize,
    }

    #[async_trait]
    impl ServiceDiscovery for MockDiscovery {
        /// Builds `expected` backends at `1.1.1.1:<i>` plus an enablement map with
        /// one `true` entry per backend.
        async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
            let mut d = BTreeSet::new();
            let mut m = HashMap::with_capacity(self.expected);
            for i in 0..self.expected {
                let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
                // NOTE(review): the map is keyed by the loop index here, while other
                // tests key it by `Backend::hash_key()`; presumably these entries match
                // no backend and are only filler. Confirm this is intended.
                m.insert(i as u64, true);
                d.insert(b);
            }
            Ok((d, m))
        }
    }

    /// Runs `update()` on a spawned task and checks that, as soon as the backend set
    /// is observed to be non-empty, it has the full size and a selection succeeds.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_consistency() {
        let expected = 3000;
        let discovery = MockDiscovery { expected };
        let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
            Backends::new(Box::new(discovery)),
        ));
        let lb2 = lb.clone();

        tokio::spawn(async move {
            assert!(lb2.update().await.is_ok());
        });
        let mut backend_count = 0;
        // This loop spins without an `.await`, so it never yields to the runtime;
        // presumably it relies on the second worker thread to run the spawned
        // update. Confirm before reducing `worker_threads`.
        while backend_count == 0 {
            let backends = lb.backends();
            backend_count = backends.backends.load_full().len();
        }
        assert_eq!(backend_count, expected);
        assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
    }
}
}

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-lru"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-memory-cache"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
@ -17,11 +17,11 @@ name = "pingora_memory_cache"
path = "src/lib.rs"
[dependencies]
TinyUFO = { version = "0.2.0", path = "../tinyufo" }
TinyUFO = { version = "0.3.0", path = "../tinyufo" }
ahash = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
async-trait = { workspace = true }
pingora-error = { version = "0.2.0", path = "../pingora-error" }
pingora-error = { version = "0.3.0", path = "../pingora-error" }
log = { workspace = true }
parking_lot = "0"
pingora-timeout = { version = "0.2.0", path = "../pingora-timeout" }
pingora-timeout = { version = "0.3.0", path = "../pingora-timeout" }

View file

@ -45,6 +45,14 @@ impl CacheStatus {
Self::LockHit => "lock_hit",
}
}
/// Returns whether this status represents a cache hit.
pub fn is_hit(&self) -> bool {
match self {
CacheStatus::Hit | CacheStatus::LockHit => true,
CacheStatus::Miss | CacheStatus::Expired => false,
}
}
}
#[derive(Debug, Clone)]

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-openssl"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"

View file

@ -27,7 +27,7 @@ use openssl_sys::{
use std::ffi::CString;
use std::os::raw;
fn cvt(r: c_int) -> Result<c_int, ErrorStack> {
fn cvt(r: c_long) -> Result<c_long, ErrorStack> {
if r != 1 {
Err(ErrorStack::get())
} else {
@ -42,8 +42,8 @@ extern "C" {
namelen: size_t,
) -> c_int;
pub fn SSL_use_certificate(ssl: *const SSL, cert: *mut X509) -> c_int;
pub fn SSL_use_PrivateKey(ctx: *const SSL, key: *mut EVP_PKEY) -> c_int;
pub fn SSL_use_certificate(ssl: *mut SSL, cert: *mut X509) -> c_int;
pub fn SSL_use_PrivateKey(ssl: *mut SSL, key: *mut EVP_PKEY) -> c_int;
pub fn SSL_set_cert_cb(
ssl: *mut SSL,
@ -64,9 +64,9 @@ pub fn add_host(verify_param: &mut X509VerifyParamRef, host: &str) -> Result<(),
unsafe {
cvt(X509_VERIFY_PARAM_add1_host(
verify_param.as_ptr(),
host.as_ptr() as *const _,
host.as_ptr() as *const c_char,
host.len(),
))
) as c_long)
.map(|_| ())
}
}
@ -84,7 +84,7 @@ pub fn ssl_set_verify_cert_store(
SSL_CTRL_SET_VERIFY_CERT_STORE,
1, // increase the ref count of X509Store so that ssl_ctx can outlive X509StoreRef
cert_store.as_ptr() as *mut c_void,
) as i32)?;
))?;
}
Ok(())
}
@ -94,7 +94,7 @@ pub fn ssl_set_verify_cert_store(
/// See [SSL_use_certificate](https://www.openssl.org/docs/man1.1.1/man3/SSL_use_certificate.html).
pub fn ssl_use_certificate(ssl: &mut SslRef, cert: &X509Ref) -> Result<(), ErrorStack> {
unsafe {
cvt(SSL_use_certificate(ssl.as_ptr(), cert.as_ptr()))?;
cvt(SSL_use_certificate(ssl.as_ptr() as *mut SSL, cert.as_ptr() as *mut X509) as c_long)?;
}
Ok(())
}
@ -107,7 +107,7 @@ where
T: HasPrivate,
{
unsafe {
cvt(SSL_use_PrivateKey(ssl.as_ptr(), key.as_ptr()))?;
cvt(SSL_use_PrivateKey(ssl.as_ptr() as *mut SSL, key.as_ptr() as *mut EVP_PKEY) as c_long)?;
}
Ok(())
}
@ -123,7 +123,7 @@ pub fn ssl_add_chain_cert(ssl: &mut SslRef, cert: &X509Ref) -> Result<(), ErrorS
SSL_CTRL_CHAIN_CERT,
1, // increase the ref count of X509 so that ssl can outlive X509StoreRef
cert.as_ptr() as *mut c_void,
) as i32)?;
))?;
}
Ok(())
}
@ -137,14 +137,17 @@ pub fn ssl_set_renegotiate_mode_freely(_ssl: &mut SslRef) {}
///
/// See [set_groups_list](https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set1_curves.html).
pub fn ssl_set_groups_list(ssl: &mut SslRef, groups: &str) -> Result<(), ErrorStack> {
let groups = CString::new(groups).unwrap();
if groups.contains('\0') {
return Err(ErrorStack::get());
}
let groups = CString::new(groups).map_err(|_| ErrorStack::get())?;
unsafe {
cvt(SSL_ctrl(
ssl.as_ptr(),
SSL_CTRL_SET_GROUPS_LIST,
0,
groups.as_ptr() as *mut c_void,
) as i32)?;
))?;
}
Ok(())
}
@ -207,3 +210,22 @@ pub fn is_suspended_for_cert(error: &openssl::ssl::Error) -> bool {
pub unsafe fn ssl_mut(ssl: &SslRef) -> &mut SslRef {
SslRef::from_ptr_mut(ssl.as_ptr())
}
#[cfg(test)]
mod tests {
    use super::*;
    use openssl::ssl::{SslContextBuilder, SslMethod};

    /// `ssl_set_groups_list` must accept a well-formed group list and reject a
    /// string that carries an interior NUL byte.
    #[test]
    fn test_ssl_set_groups_list() {
        let context = SslContextBuilder::new(SslMethod::tls()).unwrap().build();
        let ssl = Ssl::new(&context).unwrap();
        // `ssl_mut` is unsafe because it turns a shared reference into a mutable
        // one; `ssl` is local to this test and only reached through `ssl_ref`.
        let ssl_ref = unsafe { ssl_mut(&ssl) };

        // Valid input
        let accepted = ssl_set_groups_list(ssl_ref, "P-256:P-384");
        assert!(accepted.is_ok());

        // Invalid input (contains null byte)
        let rejected = ssl_set_groups_list(ssl_ref, "P-256\0P-384");
        assert!(rejected.is_err());
    }
}

View file

@ -25,6 +25,7 @@ pub use tokio_openssl as tokio_ssl;
pub mod ext;
// export commonly used libs
pub use ssl_lib::dh;
pub use ssl_lib::error;
pub use ssl_lib::hash;
pub use ssl_lib::nid;

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-pool"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
@ -23,7 +23,7 @@ lru = { workspace = true }
log = { workspace = true }
parking_lot = "0.12"
crossbeam-queue = "0.3"
pingora-timeout = { version = "0.2.0", path = "../pingora-timeout" }
pingora-timeout = { version = "0.3.0", path = "../pingora-timeout" }
[dev-dependencies]
tokio-test = "0.4"

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-proxy"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
@ -18,12 +18,12 @@ name = "pingora_proxy"
path = "src/lib.rs"
[dependencies]
pingora-error = { version = "0.2.0", path = "../pingora-error" }
pingora-core = { version = "0.2.0", path = "../pingora-core", default-features = false }
pingora-timeout = { version = "0.2.0", path = "../pingora-timeout" }
pingora-cache = { version = "0.2.0", path = "../pingora-cache", default-features = false }
pingora-error = { version = "0.3.0", path = "../pingora-error" }
pingora-core = { version = "0.3.0", path = "../pingora-core", default-features = false }
pingora-timeout = { version = "0.3.0", path = "../pingora-timeout" }
pingora-cache = { version = "0.3.0", path = "../pingora-cache", default-features = false }
tokio = { workspace = true, features = ["macros", "net"] }
pingora-http = { version = "0.2.0", path = "../pingora-http" }
pingora-http = { version = "0.3.0", path = "../pingora-http" }
http = { workspace = true }
futures = "0.3"
bytes = { workspace = true }
@ -44,7 +44,8 @@ env_logger = "0.9"
hyperlocal = "0.8"
hyper = "0.14"
tokio-tungstenite = "0.20.1"
pingora-load-balancing = { version = "0.2.0", path = "../pingora-load-balancing" }
pingora-limits = { version = "0.3.0", path = "../pingora-limits" }
pingora-load-balancing = { version = "0.3.0", path = "../pingora-load-balancing" }
prometheus = "0"
futures-util = "0.3"
serde = { version = "1.0", features = ["derive"] }
@ -55,3 +56,10 @@ serde_yaml = "0.8"
default = ["openssl"]
openssl = ["pingora-core/openssl", "pingora-cache/openssl"]
boringssl = ["pingora-core/boringssl", "pingora-cache/boringssl"]
# or locally cargo doc --config "build.rustdocflags='--cfg doc_async_trait'"
[package.metadata.docs.rs]
rustdoc-args = ["--cfg", "doc_async_trait"]
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(doc_async_trait)'] }

View file

@ -0,0 +1,117 @@
use async_trait::async_trait;
use once_cell::sync::Lazy;
use pingora_core::prelude::*;
use pingora_http::{RequestHeader, ResponseHeader};
use pingora_limits::rate::Rate;
use pingora_load_balancing::prelude::{RoundRobin, TcpHealthCheck};
use pingora_load_balancing::LoadBalancer;
use pingora_proxy::{http_proxy_service, ProxyHttp, Session};
use std::sync::Arc;
use std::time::Duration;
/// Starts a proxy on `0.0.0.0:6188` that round-robins over two upstreams and
/// runs their health checks as a background service.
fn main() {
    let mut server = Server::new(Some(Opt::default())).unwrap();
    server.bootstrap();

    // Upstream pool, probed with a TCP health check once per second.
    let mut balancer = LoadBalancer::try_from_iter(["1.1.1.1:443", "1.0.0.1:443"]).unwrap();
    balancer.set_health_check(TcpHealthCheck::new());
    balancer.health_check_frequency = Some(Duration::from_secs(1));

    // The background service owns the balancer; `task()` yields the handle
    // that is passed to the proxy below.
    let health_service = background_service("health check", balancer);
    let shared_balancer = health_service.task();

    // HTTP proxy service that selects its upstream through the balancer.
    let mut proxy = http_proxy_service(&server.configuration, LB(shared_balancer));
    proxy.add_tcp("0.0.0.0:6188");

    server.add_service(health_service);
    server.add_service(proxy);
    server.run_forever();
}
/// Proxy application wrapping the shared round-robin load balancer.
pub struct LB(Arc<LoadBalancer<RoundRobin>>);

impl LB {
    /// Returns the request's `appid` header value as an owned string.
    ///
    /// Yields `None` when the header is absent or when `to_str()` rejects its
    /// value.
    pub fn get_request_appid(&self, session: &mut Session) -> Option<String> {
        let raw = session.req_header().headers.get("appid")?;
        raw.to_str().ok().map(|text| text.to_string())
    }
}
// Rate limiter
// Process-wide `Rate`, built lazily on first use with a one-second interval.
// `request_filter` records one event per request in it, keyed by the value of
// the request's `appid` header.
static RATE_LIMITER: Lazy<Rate> = Lazy::new(|| Rate::new(Duration::from_secs(1)));
// max request per second per client
// "Client" here means one `appid` value: `request_filter` compares the count
// returned by `RATE_LIMITER.observe` for that appid against this threshold and
// answers 429 once the count exceeds it.
static MAX_REQ_PER_SEC: isize = 1;
#[async_trait]
impl ProxyHttp for LB {
type CTX = ();
fn new_ctx(&self) {}
async fn upstream_peer(
&self,
_session: &mut Session,
_ctx: &mut Self::CTX,
) -> Result<Box<HttpPeer>> {
let upstream = self.0.select(b"", 256).unwrap();
// Set SNI
let peer = Box::new(HttpPeer::new(upstream, true, "one.one.one.one".to_string()));
Ok(peer)
}
async fn upstream_request_filter(
&self,
_session: &mut Session,
upstream_request: &mut RequestHeader,
_ctx: &mut Self::CTX,
) -> Result<()>
where
Self::CTX: Send + Sync,
{
upstream_request
.insert_header("Host", "one.one.one.one")
.unwrap();
Ok(())
}
async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result<bool>
where
Self::CTX: Send + Sync,
{
let appid = match self.get_request_appid(session) {
None => return Ok(false), // no client appid found, skip rate limiting
Some(addr) => addr,
};
// retrieve the current window requests
let curr_window_requests = RATE_LIMITER.observe(&appid, 1);
if curr_window_requests > MAX_REQ_PER_SEC {
// rate limited, return 429
let mut header = ResponseHeader::build(429, None).unwrap();
header
.insert_header("X-Rate-Limit-Limit", MAX_REQ_PER_SEC.to_string())
.unwrap();
header.insert_header("X-Rate-Limit-Remaining", "0").unwrap();
header.insert_header("X-Rate-Limit-Reset", "1").unwrap();
session.set_keepalive(None);
session
.write_response_header(Box::new(header), true)
.await?;
return Ok(true);
}
Ok(false)
}
}

View file

@ -0,0 +1,127 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use clap::Parser;
use pingora_core::modules::http::HttpModules;
use pingora_core::server::configuration::Opt;
use pingora_core::server::Server;
use pingora_core::upstreams::peer::HttpPeer;
use pingora_core::Result;
use pingora_http::RequestHeader;
use pingora_proxy::{ProxyHttp, Session};
/// This example shows how to build and import 3rd party modules
/// A simple ACL to check "Authorization: basic $credential" header
mod my_acl {
    use super::*;
    use pingora_core::modules::http::{HttpModule, HttpModuleBuilder, Module};
    use pingora_error::{Error, ErrorType::HTTPStatus};
    use std::any::Any;

    /// Per-request context of the ACL module.
    struct MyAclCtx {
        /// Complete `Authorization` header value a request has to present.
        credential_header: String,
    }

    // How the module inspects each request.
    #[async_trait]
    impl HttpModule for MyAclCtx {
        /// Accepts the request only when its `Authorization` header is
        /// byte-for-byte equal to the expected value; otherwise fails with an
        /// `HTTPStatus(403)` error.
        async fn request_header_filter(&mut self, req: &mut RequestHeader) -> Result<()> {
            let expected = self.credential_header.as_bytes();
            match req.headers.get(http::header::AUTHORIZATION) {
                None => Error::e_explain(HTTPStatus(403), "Auth failed, no auth header"),
                Some(presented) if presented.as_bytes() == expected => Ok(()),
                Some(_) => Error::e_explain(HTTPStatus(403), "Auth failed, credential mismatch"),
            }
        }

        // boilerplate code for all modules
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }
    }

    /// The singleton builder object that gets attached to the server.
    pub struct MyAcl {
        pub credential: String,
    }

    impl HttpModuleBuilder for MyAcl {
        /// Creates the per-request context; called when a new request arrives.
        fn init(&self) -> Module {
            // Pre-render the full header value so the filter is a plain
            // comparison. The value could instead live in `MyAcl` behind an
            // `Arc` and be shared with every context.
            let credential_header = format!("basic {}", self.credential);
            Box::new(MyAclCtx { credential_header })
        }
    }
}
/// Proxy application that guards every request with the `my_acl` module.
pub struct MyProxy;

#[async_trait]
impl ProxyHttp for MyProxy {
    type CTX = ();
    fn new_ctx(&self) -> Self::CTX {}

    /// Registers the ACL module; only called once, when the server starts.
    ///
    /// Because this overrides the trait's default body, only the modules added
    /// here are registered.
    fn init_downstream_modules(&self, modules: &mut HttpModules) {
        let acl = my_acl::MyAcl {
            credential: "testcode".into(),
        };
        modules.add_module(Box::new(acl));
    }

    /// Always proxies to `1.1.1.1:443` over TLS with SNI `one.one.one.one`.
    async fn upstream_peer(
        &self,
        _session: &mut Session,
        _ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        let sni = "one.one.one.one".to_string();
        Ok(Box::new(HttpPeer::new(("1.1.1.1", 443), true, sni)))
    }
}
// RUST_LOG=INFO cargo run --example use_module
// curl 127.0.0.1:6193 -H "Host: one.one.one.one" -v
// curl 127.0.0.1:6193 -H "Host: one.one.one.one" -H "Authorization: basic testcode"
// curl 127.0.0.1:6193 -H "Host: one.one.one.one" -H "Authorization: basic wrong" -v
/// Starts the ACL-protected proxy on `0.0.0.0:6193`.
fn main() {
    env_logger::init();

    // The server is configured from the command line arguments.
    let mut my_server = Server::new(Some(Opt::parse())).unwrap();
    my_server.bootstrap();

    let mut proxy_service = pingora_proxy::http_proxy_service(&my_server.configuration, MyProxy);
    proxy_service.add_tcp("0.0.0.0:6193");
    my_server.add_service(proxy_service);

    my_server.run_forever();
}

View file

@ -35,9 +35,6 @@
//!
//! See `examples/load_balancer.rs` for a detailed example.
// enable nightly feature async trait so that the docs are cleaner
#![cfg_attr(doc_async_trait, feature(async_fn_in_trait))]
use async_trait::async_trait;
use bytes::Bytes;
use futures::future::FutureExt;
@ -110,6 +107,14 @@ impl<SV> HttpProxy<SV> {
}
}
/// Lets the wrapped `ProxyHttp` implementation register its downstream
/// modules by handing it this proxy's `downstream_modules`.
///
/// `http_proxy_service_with_name` calls this right after constructing the
/// proxy and before wrapping it in a `Service`.
fn handle_init_modules(&mut self)
where
SV: ProxyHttp,
{
self.inner
.init_downstream_modules(&mut self.downstream_modules);
}
async fn handle_new_request(
&self,
mut downstream_session: Box<HttpSession>,
@ -222,7 +227,8 @@ impl<SV> HttpProxy<SV> {
}),
)
}
Err(e) => {
Err(mut e) => {
e.as_up();
let new_err = self.inner.fail_to_connect(session, &peer, ctx, e);
(false, Some(new_err.into_up()))
}
@ -739,7 +745,10 @@ use pingora_core::services::listening::Service;
/// Create a [Service] from the user implemented [ProxyHttp].
///
/// The returned [Service] can be hosted by a [pingora_core::server::Server] directly.
pub fn http_proxy_service<SV>(conf: &Arc<ServerConf>, inner: SV) -> Service<HttpProxy<SV>> {
pub fn http_proxy_service<SV>(conf: &Arc<ServerConf>, inner: SV) -> Service<HttpProxy<SV>>
where
SV: ProxyHttp,
{
http_proxy_service_with_name(conf, inner, "Pingora HTTP Proxy Service")
}
@ -750,11 +759,11 @@ pub fn http_proxy_service_with_name<SV>(
conf: &Arc<ServerConf>,
inner: SV,
name: &str,
) -> Service<HttpProxy<SV>> {
) -> Service<HttpProxy<SV>>
where
SV: ProxyHttp,
{
let mut proxy = HttpProxy::new(inner, conf.clone());
// Add disabled downstream compression module by default
proxy
.downstream_modules
.add_module(ResponseCompressionBuilder::enable(0));
proxy.handle_init_modules();
Service::new(name.to_string(), proxy)
}

View file

@ -165,7 +165,9 @@ impl<SV> HttpProxy<SV> {
} else {
break None;
}
} // else continue to serve stale
}
// else continue to serve stale
session.cache.set_stale_updating();
} else if session.cache.is_cache_lock_writer() {
// stale while revalidate logic for the writer
let will_serve_stale = session.cache.can_serve_stale_updating()
@ -182,6 +184,7 @@ impl<SV> HttpProxy<SV> {
new_app.process_subrequest(subrequest, sub_req_ctx).await;
});
// continue to serve stale for this request
session.cache.set_stale_updating();
} else {
// return to fetch from upstream
break None;

View file

@ -98,10 +98,7 @@ impl<SV> HttpProxy<SV> {
);
match ret {
Ok((_first, _second)) => {
client_session.respect_keepalive();
(true, true, None)
}
Ok((_first, _second)) => (true, true, None),
Err(e) => (false, false, Some(e)),
}
}
@ -549,6 +546,11 @@ impl<SV> HttpProxy<SV> {
// affected by the request_body_filter
let end_of_body = end_of_body || data.is_none();
session
.downstream_modules_ctx
.request_body_filter(&mut data, end_of_body)
.await?;
self.inner
.request_body_filter(session, &mut data, end_of_body, ctx)
.await?;

View file

@ -525,6 +525,11 @@ impl<SV> HttpProxy<SV> {
SV: ProxyHttp + Send + Sync,
SV::CTX: Send + Sync,
{
session
.downstream_modules_ctx
.request_body_filter(&mut data, end_of_body)
.await?;
self.inner
.request_body_filter(session, &mut data, end_of_body, ctx)
.await?;

View file

@ -40,6 +40,17 @@ pub trait ProxyHttp {
ctx: &mut Self::CTX,
) -> Result<Box<HttpPeer>>;
/// Set up downstream modules.
///
/// In this phase, users can add or configure [HttpModules] before the server starts up.
///
/// In the default implementation of this method, [ResponseCompressionBuilder] is added
/// and disabled.
///
/// An implementation that overrides this method replaces the default body, so the
/// compression module is only registered if the override adds it itself.
fn init_downstream_modules(&self, modules: &mut HttpModules) {
// Add disabled downstream compression module by default
modules.add_module(ResponseCompressionBuilder::enable(0));
}
/// Handle the incoming request.
///
/// In this phase, users can parse, validate, rate limit, perform access control and/or

View file

@ -133,6 +133,60 @@ async fn test_ws_server_ends_conn() {
assert!(ws_stream.next().await.is_none());
}
/// With `x-write-timeout: 1` and a client that pauses for two seconds before
/// reading, the body stream must surface an error.
#[tokio::test]
async fn test_download_timeout() {
    use hyper::body::HttpBody;
    use tokio::time::sleep;

    init();

    let client = hyper::Client::new();
    let request = hyper::Request::builder()
        .uri("http://127.0.0.1:6147/download/".parse::<hyper::Uri>().unwrap())
        .header("x-write-timeout", "1")
        .body(hyper::Body::empty())
        .unwrap();
    let mut response = client.request(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    // Pause before reading the body, for longer than the requested timeout.
    sleep(Duration::from_secs(2)).await;

    // Drain the body, remembering whether any chunk was an error.
    let mut saw_error = false;
    while let Some(chunk) = response.body_mut().data().await {
        saw_error |= chunk.is_err();
    }
    assert!(saw_error);
}
/// Same download as `test_download_timeout`, but the request also carries
/// `x-min-rate: 10000`; the body must then arrive without any error.
#[tokio::test]
async fn test_download_timeout_min_rate() {
    use hyper::body::HttpBody;
    use tokio::time::sleep;

    init();

    let client = hyper::Client::new();
    let request = hyper::Request::builder()
        .uri("http://127.0.0.1:6147/download/".parse::<hyper::Uri>().unwrap())
        .header("x-write-timeout", "1")
        .header("x-min-rate", "10000")
        .body(hyper::Body::empty())
        .unwrap();
    let mut response = client.request(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    // Pause before reading the body, for longer than the requested timeout.
    sleep(Duration::from_secs(2)).await;

    // Drain the body, remembering whether any chunk was an error.
    let mut saw_error = false;
    while let Some(chunk) = response.body_mut().data().await {
        saw_error |= chunk.is_err();
    }
    // no error as write timeout is overridden by min rate
    assert!(!saw_error);
}
mod test_cache {
use super::*;
use std::str::FromStr;
@ -1319,7 +1373,7 @@ mod test_cache {
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let headers = res.headers();
assert_eq!(headers["x-cache-status"], "stale");
assert_eq!(headers["x-cache-status"], "stale-updating");
assert_eq!(res.text().await.unwrap(), "hello world");
});
// sleep just a little to make sure the req above gets the cache lock
@ -1333,7 +1387,7 @@ mod test_cache {
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let headers = res.headers();
assert_eq!(headers["x-cache-status"], "stale");
assert_eq!(headers["x-cache-status"], "stale-updating");
assert_eq!(res.text().await.unwrap(), "hello world");
});
let task3 = tokio::spawn(async move {
@ -1345,7 +1399,7 @@ mod test_cache {
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let headers = res.headers();
assert_eq!(headers["x-cache-status"], "stale");
assert_eq!(headers["x-cache-status"], "stale-updating");
assert_eq!(res.text().await.unwrap(), "hello world");
});
@ -1382,7 +1436,7 @@ mod test_cache {
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let headers = res.headers();
assert_eq!(headers["x-cache-status"], "stale");
assert_eq!(headers["x-cache-status"], "stale-updating");
assert_eq!(res.text().await.unwrap(), "hello world");
// wait for the background request to finish

View file

@ -296,6 +296,15 @@ http {
}
}
location /download/ {
content_by_lua_block {
ngx.req.read_body()
local body = string.rep("A", 4194304)
ngx.header["Content-Length"] = #body
ngx.print(body)
}
}
location /tls_verify {
keepalive_timeout 0;
return 200;

View file

@ -39,6 +39,7 @@ use pingora_proxy::{ProxyHttp, Session};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
pub struct ExampleProxyHttps {}
@ -230,6 +231,17 @@ impl ProxyHttp for ExampleProxyHttp {
async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result<bool> {
let req = session.req_header();
let write_timeout = req
.headers
.get("x-write-timeout")
.and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()));
let min_rate = req
.headers
.get("x-min-rate")
.and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()));
let downstream_compression = req.headers.get("x-downstream-compression").is_some();
if !downstream_compression {
// enable upstream compression for all requests by default
@ -242,6 +254,13 @@ impl ProxyHttp for ExampleProxyHttp {
.adjust_level(0);
}
if let Some(min_rate) = min_rate {
session.set_min_send_rate(min_rate);
}
if let Some(write_timeout) = write_timeout {
session.set_write_timeout(Duration::from_secs(write_timeout));
}
Ok(false)
}
@ -454,6 +473,9 @@ impl ProxyHttp for ExampleProxyCache {
CachePhase::Hit => upstream_response.insert_header("x-cache-status", "hit")?,
CachePhase::Miss => upstream_response.insert_header("x-cache-status", "miss")?,
CachePhase::Stale => upstream_response.insert_header("x-cache-status", "stale")?,
CachePhase::StaleUpdating => {
upstream_response.insert_header("x-cache-status", "stale-updating")?
}
CachePhase::Expired => {
upstream_response.insert_header("x-cache-status", "expired")?
}

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-runtime"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"

View file

@ -1,6 +1,6 @@
[package]
name = "pingora-timeout"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
@ -24,7 +24,6 @@ tokio = { workspace = true, features = [
"sync",
] }
pin-project-lite = "0.2"
futures = "0.3"
once_cell = { workspace = true }
parking_lot = "0.12"
thread_local = "1.0"

View file

@ -50,7 +50,7 @@ fn check_clock_thread(tm: &Arc<TimerManager>) {
pub struct FastTimeout(Duration);
impl ToTimeout for FastTimeout {
fn timeout(&self) -> BoxFuture<'static, ()> {
fn timeout(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
Box::pin(TIMER_MANAGER.register_timer(self.0).poll())
}

View file

@ -39,7 +39,6 @@ pub mod timer;
pub use fast_timeout::fast_sleep as sleep;
pub use fast_timeout::fast_timeout as timeout;
use futures::future::BoxFuture;
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
@ -50,7 +49,7 @@ use tokio::time::{sleep as tokio_sleep, Duration};
///
/// Users don't need to interact with this trait
pub trait ToTimeout {
fn timeout(&self) -> BoxFuture<'static, ()>;
fn timeout(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>>;
fn create(d: Duration) -> Self;
}
@ -60,7 +59,7 @@ pub trait ToTimeout {
pub struct TokioTimeout(Duration);
impl ToTimeout for TokioTimeout {
fn timeout(&self) -> BoxFuture<'static, ()> {
fn timeout(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
Box::pin(tokio_sleep(self.0))
}
@ -100,7 +99,7 @@ pin_project! {
#[pin]
value: T,
#[pin]
delay: Option<BoxFuture<'static, ()>>,
delay: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
callback: F, // callback to create the timer
}
}

View file

@ -1,6 +1,6 @@
[package]
name = "pingora"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
@ -22,12 +22,12 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs"]
[dependencies]
pingora-core = { version = "0.2.0", path = "../pingora-core", default-features = false }
pingora-http = { version = "0.2.0", path = "../pingora-http" }
pingora-timeout = { version = "0.2.0", path = "../pingora-timeout" }
pingora-load-balancing = { version = "0.2.0", path = "../pingora-load-balancing", optional = true, default-features = false }
pingora-proxy = { version = "0.2.0", path = "../pingora-proxy", optional = true, default-features = false }
pingora-cache = { version = "0.2.0", path = "../pingora-cache", optional = true, default-features = false }
pingora-core = { version = "0.3.0", path = "../pingora-core", default-features = false }
pingora-http = { version = "0.3.0", path = "../pingora-http" }
pingora-timeout = { version = "0.3.0", path = "../pingora-timeout" }
pingora-load-balancing = { version = "0.3.0", path = "../pingora-load-balancing", optional = true, default-features = false }
pingora-proxy = { version = "0.3.0", path = "../pingora-proxy", optional = true, default-features = false }
pingora-cache = { version = "0.3.0", path = "../pingora-cache", optional = true, default-features = false }
[dev-dependencies]
clap = { version = "3.2.25", features = ["derive"] }
@ -63,3 +63,4 @@ boringssl = [
proxy = ["pingora-proxy"]
lb = ["pingora-load-balancing", "proxy"]
cache = ["pingora-cache"]
time = []

View file

@ -1,6 +1,6 @@
[package]
name = "TinyUFO"
version = "0.2.0"
version = "0.3.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
edition = "2021"
license = "Apache-2.0"
@ -17,7 +17,7 @@ path = "src/lib.rs"
[dependencies]
ahash = { workspace = true }
flurry = "<0.5.0" # Try not to require Rust 1.71
flurry = "0.5"
parking_lot = "0"
crossbeam-queue = "0"
crossbeam-skiplist = "0"
@ -28,7 +28,7 @@ lru = "0"
zipf = "7"
moka = { version = "0", features = ["sync"] }
dhat = "0"
quick_cache = "0.4"
quick_cache = "0.6"
triomphe = "<=0.1.11" # 0.1.12 requires Rust 1.76
[[bench]]

View file

@ -36,23 +36,36 @@ impl Estimator {
fn optimal(items: usize) -> Self {
let (slots, hashes) = Self::optimal_paras(items);
Self::new(hashes, slots)
Self::new(hashes, slots, RandomState::new)
}
fn compact(items: usize) -> Self {
let (slots, hashes) = Self::optimal_paras(items / 100);
Self::new(hashes, slots)
Self::new(hashes, slots, RandomState::new)
}
/// Create a new `Estimator` with the given amount of hashes and columns (slots).
pub fn new(hashes: usize, slots: usize) -> Self {
/// Test-only constructor: sized like `optimal` (via `optimal_paras(items)`),
/// but every hash row is built with `RandomState::with_seeds(2, 3, 4, 5)`
/// rather than `RandomState::new`, so the seeds are fixed across runs.
#[cfg(test)]
fn seeded(items: usize) -> Self {
let (slots, hashes) = Self::optimal_paras(items);
Self::new(hashes, slots, || RandomState::with_seeds(2, 3, 4, 5))
}
/// Test-only constructor: sized like `compact` (via `optimal_paras(items / 100)`),
/// but every hash row is built with `RandomState::with_seeds(2, 3, 4, 5)`
/// rather than `RandomState::new`, so the seeds are fixed across runs.
#[cfg(test)]
fn seeded_compact(items: usize) -> Self {
let (slots, hashes) = Self::optimal_paras(items / 100);
Self::new(hashes, slots, || RandomState::with_seeds(2, 3, 4, 5))
}
/// Create a new `Estimator` with the given amount of hashes and columns (slots) using
/// the given random source.
pub fn new(hashes: usize, slots: usize, random: impl Fn() -> RandomState) -> Self {
let mut estimator = Vec::with_capacity(hashes);
for _ in 0..hashes {
let mut slot = Vec::with_capacity(slots);
for _ in 0..slots {
slot.push(AtomicU8::new(0));
}
estimator.push((slot.into_boxed_slice(), RandomState::new()));
estimator.push((slot.into_boxed_slice(), random()));
}
Estimator {
@ -161,6 +174,26 @@ impl TinyLfu {
window_limit: cache_size * 8,
}
}
/// Test-only constructor whose estimator comes from `Estimator::seeded`
/// (fixed hash seeds). The window counter starts at its default value and the
/// window limit is `cache_size * 8`.
#[cfg(test)]
pub fn new_seeded(cache_size: usize) -> Self {
Self {
estimator: Estimator::seeded(cache_size),
window_counter: Default::default(),
// 8x: just a heuristic to balance the memory usage and accuracy
window_limit: cache_size * 8,
}
}
/// Test-only constructor whose estimator comes from `Estimator::seeded_compact`
/// (fixed hash seeds, sized from `cache_size / 100`). The window counter starts
/// at its default value and the window limit is `cache_size * 8`.
#[cfg(test)]
pub fn new_compact_seeded(cache_size: usize) -> Self {
Self {
estimator: Estimator::seeded_compact(cache_size),
window_counter: Default::default(),
// 8x: just a heuristic to balance the memory usage and accuracy
window_limit: cache_size * 8,
}
}
}
#[cfg(test)]

View file

@ -473,7 +473,9 @@ mod tests {
#[test]
fn test_evict_from_small() {
let cache = TinyUfo::new(5, 5);
let mut cache = TinyUfo::new(5, 5);
cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
cache.queues.estimator = TinyLfu::new_seeded(5);
cache.put(1, 1, 1);
cache.put(2, 2, 2);
@ -496,7 +498,9 @@ mod tests {
#[test]
fn test_evict_from_small_to_main() {
let cache = TinyUfo::new(5, 5);
let mut cache = TinyUfo::new(5, 5);
cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
cache.queues.estimator = TinyLfu::new_seeded(5);
cache.put(1, 1, 1);
cache.put(2, 2, 2);
@ -510,20 +514,30 @@ mod tests {
assert_eq!(cache.peek_queue(2), Some(SMALL));
assert_eq!(cache.peek_queue(3), Some(SMALL));
let evicted = cache.put(4, 4, 1);
let evicted = cache.put(4, 4, 2);
assert_eq!(evicted.len(), 1);
assert_eq!(evicted[0].data, 2);
assert_eq!(evicted[0].weight, 2);
assert_eq!(cache.peek_queue(1), Some(MAIN));
// 2 is evicted because 1 is in main
assert_eq!(cache.peek_queue(2), None);
assert_eq!(cache.peek_queue(3), Some(SMALL));
assert_eq!(cache.peek_queue(4), Some(SMALL));
// either 2, 3, or 4 was evicted. Check evicted for which.
let mut remaining = vec![2, 3, 4];
remaining.remove(
remaining
.iter()
.position(|x| *x == evicted[0].data)
.unwrap(),
);
assert_eq!(cache.peek_queue(evicted[0].key), None);
for k in remaining {
assert_eq!(cache.peek_queue(k), Some(SMALL));
}
}
#[test]
fn test_evict_reentry() {
let cache = TinyUfo::new(5, 5);
let mut cache = TinyUfo::new(5, 5);
cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
cache.queues.estimator = TinyLfu::new_seeded(5);
cache.put(1, 1, 1);
cache.put(2, 2, 2);
@ -555,7 +569,9 @@ mod tests {
#[test]
fn test_evict_entry_denied() {
let cache = TinyUfo::new(5, 5);
let mut cache = TinyUfo::new(5, 5);
cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
cache.queues.estimator = TinyLfu::new_seeded(5);
cache.put(1, 1, 1);
cache.put(2, 2, 2);
@ -583,7 +599,9 @@ mod tests {
#[test]
fn test_force_put() {
let cache = TinyUfo::new(5, 5);
let mut cache = TinyUfo::new(5, 5);
cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
cache.queues.estimator = TinyLfu::new_seeded(5);
cache.put(1, 1, 1);
cache.put(2, 2, 2);
@ -612,7 +630,9 @@ mod tests {
#[test]
fn test_evict_from_main() {
let cache = TinyUfo::new(5, 5);
let mut cache = TinyUfo::new(5, 5);
cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
cache.queues.estimator = TinyLfu::new_seeded(5);
cache.put(1, 1, 1);
cache.put(2, 2, 2);
@ -649,7 +669,9 @@ mod tests {
#[test]
fn test_evict_from_small_compact() {
let cache = TinyUfo::new(5, 5);
let mut cache = TinyUfo::new(5, 5);
cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
cache.queues.estimator = TinyLfu::new_compact_seeded(5);
cache.put(1, 1, 1);
cache.put(2, 2, 2);
@ -672,7 +694,9 @@ mod tests {
#[test]
fn test_evict_from_small_to_main_compact() {
let cache = TinyUfo::new(5, 5);
let mut cache = TinyUfo::new(5, 5);
cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
cache.queues.estimator = TinyLfu::new_compact_seeded(5);
cache.put(1, 1, 1);
cache.put(2, 2, 2);