diff --git a/.bleep b/.bleep
index f71cb64..95ee21d 100644
--- a/.bleep
+++ b/.bleep
@@ -1 +1 @@
-bea67a70dff1b8a8a04d46b0c322e8fac1120d0b
\ No newline at end of file
+b5675546711d1fc8bc1a0aa28e4586e46d560024
\ No newline at end of file
diff --git a/pingora-cache/src/lib.rs b/pingora-cache/src/lib.rs
index f1f8bd2..8f88299 100644
--- a/pingora-cache/src/lib.rs
+++ b/pingora-cache/src/lib.rs
@@ -612,7 +612,15 @@ impl HttpCache {
             // Downstream read and upstream write can be decoupled
             let body_reader = inner
                 .storage
-                .lookup(key, &inner.traces.get_miss_span())
+                .lookup_streaming_write(
+                    key,
+                    inner
+                        .miss_handler
+                        .as_ref()
+                        .expect("miss handler already set")
+                        .streaming_write_tag(),
+                    &inner.traces.get_miss_span(),
+                )
                 .await?;
 
             if let Some((_meta, body_reader)) = body_reader {
diff --git a/pingora-cache/src/max_file_size.rs b/pingora-cache/src/max_file_size.rs
index 7d812f2..72caefa 100644
--- a/pingora-cache/src/max_file_size.rs
+++ b/pingora-cache/src/max_file_size.rs
@@ -72,4 +72,8 @@ impl HandleMiss for MaxFileSizeMissHandler {
     async fn finish(self: Box<Self>) -> pingora_error::Result<usize> {
         self.inner.finish().await
     }
+
+    fn streaming_write_tag(&self) -> Option<&[u8]> {
+        self.inner.streaming_write_tag()
+    }
 }
diff --git a/pingora-cache/src/memory.rs b/pingora-cache/src/memory.rs
index dec8f1a..9283110 100644
--- a/pingora-cache/src/memory.rs
+++ b/pingora-cache/src/memory.rs
@@ -20,7 +20,7 @@
 use super::*;
 use crate::key::CompactCacheKey;
-use crate::storage::{HandleHit, HandleMiss};
+use crate::storage::{streaming_write::U64WriteId, HandleHit, HandleMiss};
 use crate::trace::SpanHandle;
 
 use async_trait::async_trait;
@@ -29,6 +29,7 @@
 use parking_lot::RwLock;
 use pingora_error::*;
 use std::any::Any;
 use std::collections::HashMap;
+use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
 use tokio::sync::watch;
@@ -68,7 +69,8 @@ impl TempObject {
 /// For testing only, not for production use.
 pub struct MemCache {
     pub(crate) cached: Arc<RwLock<HashMap<String, CacheObject>>>,
-    pub(crate) temp: Arc<RwLock<HashMap<String, TempObject>>>,
+    pub(crate) temp: Arc<RwLock<HashMap<String, HashMap<u64, TempObject>>>>,
+    pub(crate) last_temp_id: AtomicU64,
 }
 
 impl MemCache {
@@ -77,6 +79,7 @@ impl MemCache {
         MemCache {
             cached: Arc::new(RwLock::new(HashMap::new())),
             temp: Arc::new(RwLock::new(HashMap::new())),
+            last_temp_id: AtomicU64::new(0),
         }
     }
 }
@@ -213,8 +216,11 @@ pub struct MemMissHandler {
     bytes_written: Arc<watch::Sender<PartialState>>,
     // these are used only in finish() to move data from temp to cache
     key: String,
+    temp_id: U64WriteId,
+    // key -> cache object
     cache: Arc<RwLock<HashMap<String, CacheObject>>>,
-    temp: Arc<RwLock<HashMap<String, TempObject>>>,
+    // key -> (temp writer id -> temp object) to support concurrent writers
+    temp: Arc<RwLock<HashMap<String, HashMap<u64, TempObject>>>>,
 }
 
 #[async_trait]
@@ -237,20 +243,48 @@ impl HandleMiss for MemMissHandler {
     async fn finish(self: Box<Self>) -> Result<usize> {
         // safe, the temp object is inserted when the miss handler is created
-        let cache_object = self.temp.read().get(&self.key).unwrap().make_cache_object();
+        let cache_object = self
+            .temp
+            .read()
+            .get(&self.key)
+            .unwrap()
+            .get(&self.temp_id.into())
+            .unwrap()
+            .make_cache_object();
         let size = cache_object.body.len(); // FIXME: this is just body size, also track meta size
         self.cache.write().insert(self.key.clone(), cache_object);
-        self.temp.write().remove(&self.key);
+        self.temp
+            .write()
+            .get_mut(&self.key)
+            .and_then(|map| map.remove(&self.temp_id.into()));
         Ok(size)
     }
+
+    fn streaming_write_tag(&self) -> Option<&[u8]> {
+        Some(self.temp_id.as_bytes())
+    }
 }
 
 impl Drop for MemMissHandler {
     fn drop(&mut self) {
-        self.temp.write().remove(&self.key);
+        self.temp
+            .write()
+            .get_mut(&self.key)
+            .and_then(|map| map.remove(&self.temp_id.into()));
     }
 }
 
+fn hit_from_temp_obj(temp_obj: &TempObject) -> Result<Option<(CacheMeta, HitHandler)>> {
+    let meta = CacheMeta::deserialize(&temp_obj.meta.0, &temp_obj.meta.1)?;
+    let partial = PartialHit {
+        body: temp_obj.body.clone(),
+        bytes_written: temp_obj.bytes_written.subscribe(),
+        bytes_read: 0,
+    };
+    let hit_handler = MemHitHandler::Partial(partial);
+    Ok(Some((meta, Box::new(hit_handler))))
+}
+
 #[async_trait]
 impl Storage for MemCache {
     async fn lookup(
@@ -261,15 +295,14 @@ impl Storage for MemCache {
         let hash = key.combined();
         // always prefer partial read otherwise fresh asset will not be visible on expired asset
         // until it is fully updated
-        if let Some(temp_obj) = self.temp.read().get(&hash) {
-            let meta = CacheMeta::deserialize(&temp_obj.meta.0, &temp_obj.meta.1)?;
-            let partial = PartialHit {
-                body: temp_obj.body.clone(),
-                bytes_written: temp_obj.bytes_written.subscribe(),
-                bytes_read: 0,
-            };
-            let hit_handler = MemHitHandler::Partial(partial);
-            Ok(Some((meta, Box::new(hit_handler))))
+        // no preference on which partial read we get (if there are multiple writers)
+        if let Some((_, temp_obj)) = self
+            .temp
+            .read()
+            .get(&hash)
+            .and_then(|map| map.iter().next())
+        {
+            hit_from_temp_obj(temp_obj)
         } else if let Some(obj) = self.cached.read().get(&hash) {
             let meta = CacheMeta::deserialize(&obj.meta.0, &obj.meta.1)?;
             let hit_handler = CompleteHit {
@@ -285,24 +318,49 @@ impl Storage for MemCache {
         }
     }
 
+    async fn lookup_streaming_write(
+        &'static self,
+        key: &CacheKey,
+        streaming_write_tag: Option<&[u8]>,
+        _trace: &SpanHandle,
+    ) -> Result<Option<(CacheMeta, HitHandler)>> {
+        let hash = key.combined();
+        let write_tag: U64WriteId = streaming_write_tag
+            .expect("tag must be set during streaming write")
+            .try_into()
+            .expect("tag must be correct length");
+        hit_from_temp_obj(
+            self.temp
+                .read()
+                .get(&hash)
+                .and_then(|map| map.get(&write_tag.into()))
+                .expect("must have partial write in progress"),
+        )
+    }
+
     async fn get_miss_handler(
         &'static self,
         key: &CacheKey,
         meta: &CacheMeta,
         _trace: &SpanHandle,
     ) -> Result<MissHandler> {
-        // TODO: support multiple concurrent writes or panic if there is already a writer
         let hash = key.combined();
         let meta = meta.serialize()?;
         let temp_obj = TempObject::new(meta);
+        let temp_id = self.last_temp_id.fetch_add(1, Ordering::Relaxed);
         let miss_handler = MemMissHandler {
             body: temp_obj.body.clone(),
             bytes_written: temp_obj.bytes_written.clone(),
             key: hash.clone(),
             cache: self.cached.clone(),
             temp: self.temp.clone(),
+            temp_id: temp_id.into(),
         };
-        self.temp.write().insert(hash, temp_obj);
+        self.temp
+            .write()
+            .entry(hash)
+            .or_default()
+            .insert(miss_handler.temp_id.into(), temp_obj);
         Ok(Box::new(miss_handler))
     }
@@ -526,7 +584,9 @@ mod test {
         );
         let temp_obj = TempObject::new(meta);
-        cache.temp.write().insert(hash.clone(), temp_obj);
+        let mut map = HashMap::new();
+        map.insert(0, temp_obj);
+        cache.temp.write().insert(hash.clone(), map);
 
         assert!(cache.temp.read().contains_key(&hash));
diff --git a/pingora-cache/src/storage.rs b/pingora-cache/src/storage.rs
index 0d4e104..7cf1e7b 100644
--- a/pingora-cache/src/storage.rs
+++ b/pingora-cache/src/storage.rs
@@ -36,13 +36,34 @@ pub enum PurgeType {
 pub trait Storage {
     // TODO: shouldn't have to be static
-    /// Lookup the storage for the given [CacheKey]
+    /// Lookup the storage for the given [CacheKey].
     async fn lookup(
         &'static self,
         key: &CacheKey,
         trace: &SpanHandle,
     ) -> Result<Option<(CacheMeta, HitHandler)>>;
 
+    /// Lookup the storage for the given [CacheKey] using a streaming write tag.
+    ///
+    /// When streaming partial writes are supported, the request that initiates the write will
+    /// also pass an optional `streaming_write_tag` so that the storage may try to find the
+    /// associated [HitHandler] for the same ongoing write.
+    ///
+    /// Therefore, when the write tag is set, the storage implementation should either return a
+    /// [HitHandler] that can be matched to that tag, or none at all. Otherwise, when the storage
+    /// supports concurrent streaming writes for the same key, the calling request may receive a
+    /// different body from the one it expected.
+    ///
+    /// By default this defers to the standard `Storage::lookup` implementation.
+    async fn lookup_streaming_write(
+        &'static self,
+        key: &CacheKey,
+        _streaming_write_tag: Option<&[u8]>,
+        trace: &SpanHandle,
+    ) -> Result<Option<(CacheMeta, HitHandler)>> {
+        self.lookup(key, trace).await
+    }
+
     /// Write the given [CacheMeta] to the storage. Return [MissHandler] to write the body later.
     async fn get_miss_handler(
         &'static self,
@@ -130,7 +151,87 @@ pub trait HandleMiss {
     async fn finish(
         self: Box<Self>, // because self is always used as a trait object
     ) -> Result<usize>;
+
+    /// Return a streaming write tag recognized by the underlying [`Storage`].
+    ///
+    /// This is an arbitrary data identifier that is used to associate this miss handler's current
+    /// write with a hit handler for the same write. This identifier will be compared by the
+    /// storage during `lookup_streaming_write`.
+    // This write tag is essentially a borrowed data blob of bytes retrieved from the miss handler
+    // and passed to storage, which means it can support strings or small data types, e.g. bytes
+    // representing a u64.
+    // The downside with the current API is that such a data blob must be owned by the miss handler
+    // and stored in a way that permits retrieval as a byte slice (not computed on the fly).
+    // But most use cases likely only require a simple integer and may not want the overhead of a
+    // Vec/String allocation or even a Cow, though such data types can also be used here.
+    fn streaming_write_tag(&self) -> Option<&[u8]> {
+        None
+    }
 }
 
 /// Miss Handler
 pub type MissHandler = Box<(dyn HandleMiss + Sync + Send)>;
+
+pub mod streaming_write {
+    /// Portable u64 (sized) write id convenience type for use with streaming writes.
+    ///
+    /// Often an integer value is sufficient for a streaming write tag. This convenience type
+    /// stores such a value and provides functions for consistent conversion between byte
+    /// sequence data types.
+    #[derive(Debug, Clone, Copy)]
+    pub struct U64WriteId([u8; 8]);
+
+    impl U64WriteId {
+        pub fn as_bytes(&self) -> &[u8] {
+            &self.0[..]
+        }
+    }
+
+    impl From<u64> for U64WriteId {
+        fn from(value: u64) -> U64WriteId {
+            U64WriteId(value.to_be_bytes())
+        }
+    }
+    impl From<U64WriteId> for u64 {
+        fn from(value: U64WriteId) -> u64 {
+            u64::from_be_bytes(value.0)
+        }
+    }
+    impl TryFrom<&[u8]> for U64WriteId {
+        type Error = std::array::TryFromSliceError;
+
+        fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
+            Ok(U64WriteId(value.try_into()?))
+        }
+    }
+
+    /// Portable u32 (sized) write id convenience type for use with streaming writes.
+    ///
+    /// Often an integer value is sufficient for a streaming write tag. This convenience type
+    /// stores such a value and provides functions for consistent conversion between byte
+    /// sequence data types.
+    #[derive(Debug, Clone, Copy)]
+    pub struct U32WriteId([u8; 4]);
+
+    impl U32WriteId {
+        pub fn as_bytes(&self) -> &[u8] {
+            &self.0[..]
+        }
+    }
+
+    impl From<u32> for U32WriteId {
+        fn from(value: u32) -> U32WriteId {
+            U32WriteId(value.to_be_bytes())
+        }
+    }
+    impl From<U32WriteId> for u32 {
+        fn from(value: U32WriteId) -> u32 {
+            u32::from_be_bytes(value.0)
+        }
+    }
+    impl TryFrom<&[u8]> for U32WriteId {
+        type Error = std::array::TryFromSliceError;
+
+        fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
+            Ok(U32WriteId(value.try_into()?))
+        }
+    }
+}
diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs
index 02174c1..42319a3 100644
--- a/pingora-proxy/tests/test_upstream.rs
+++ b/pingora-proxy/tests/test_upstream.rs
@@ -1516,6 +1516,43 @@ mod test_cache {
         task3.await.unwrap();
     }
 
+    #[tokio::test]
+    async fn test_cache_streaming_multiple_writers() {
+        // multiple streaming writers don't conflict
+        init();
+        let url = "http://127.0.0.1:6148/slow_body/test_cache_streaming_multiple_writers.txt";
+        let task1 = tokio::spawn(async move {
+            let res = reqwest::Client::new()
+                .get(url)
+                .header("x-set-hello", "everyone")
+                .send()
+                .await
+                .unwrap();
+            assert_eq!(res.status(), StatusCode::OK);
+            let headers = res.headers();
+            assert_eq!(headers["x-cache-status"], "miss");
+            assert_eq!(res.text().await.unwrap(), "hello everyone!");
+        });
+
+        let task2 = tokio::spawn(async move {
+            let res = reqwest::Client::new()
+                .get(url)
+                // don't allow using the other streaming write's result
+                .header("x-force-expire", "1")
+                .header("x-set-hello", "todo el mundo")
+                .send()
+                .await
+                .unwrap();
+            assert_eq!(res.status(), StatusCode::OK);
+            let headers = res.headers();
+            assert_eq!(headers["x-cache-status"], "miss");
+            assert_eq!(res.text().await.unwrap(), "hello todo el mundo!");
+        });
+
+        task1.await.unwrap();
+        task2.await.unwrap();
+    }
+
     #[tokio::test]
     async fn test_range_request() {
         init();
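The pieces above are easiest to follow as one round trip. Below is a minimal sketch, not part of the patch, of how a tag travels: the writer takes a unique id from an atomic counter (as `MemCache::get_miss_handler` does), the miss handler exposes it as a borrowed byte slice via `streaming_write_tag()`, and the storage parses the bytes back into the map key inside `lookup_streaming_write`. It assumes a crate depending on `pingora-cache` as modified above; everything else is plain std.

use std::sync::atomic::{AtomicU64, Ordering};

use pingora_cache::storage::streaming_write::U64WriteId;

fn main() {
    // each new miss handler takes a unique id, as in MemCache::get_miss_handler
    let last_temp_id = AtomicU64::new(0);
    let temp_id: U64WriteId = last_temp_id.fetch_add(1, Ordering::Relaxed).into();

    // the miss handler side: hand out the id as a borrowed byte tag
    let tag: &[u8] = temp_id.as_bytes();

    // the storage side: recover the integer key to find the right temp object,
    // mirroring MemCache::lookup_streaming_write
    let write_tag: U64WriteId = tag.try_into().expect("tag must be 8 bytes");
    let id: u64 = write_tag.into();
    assert_eq!(id, 0);
}

Using `to_be_bytes`/`from_be_bytes` for the conversions pins the byte order, so a tag produced by one component can be parsed by another regardless of host endianness.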
diff --git a/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf b/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf
index a41a743..6d5abd7 100644
--- a/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf
+++ b/pingora-proxy/tests/utils/conf/origin/conf/nginx.conf
@@ -408,12 +408,13 @@ http {
         location /slow_body {
             content_by_lua_block {
                 local sleep_sec = tonumber(ngx.var.http_x_set_sleep) or 1
+                local hello_to = ngx.var.http_x_set_hello or "world"
                 ngx.flush()
                 ngx.sleep(sleep_sec)
                 ngx.print("hello ")
                 ngx.flush()
                 ngx.sleep(sleep_sec)
-                ngx.print("world")
+                ngx.print(hello_to)
                 ngx.sleep(sleep_sec)
                 ngx.print("!")
             }
diff --git a/pingora-proxy/tests/utils/server_utils.rs b/pingora-proxy/tests/utils/server_utils.rs
index 7db2622..f90a27e 100644
--- a/pingora-proxy/tests/utils/server_utils.rs
+++ b/pingora-proxy/tests/utils/server_utils.rs
@@ -399,6 +399,19 @@ impl ProxyHttp for ExampleProxyCache {
         Ok(())
     }
 
+    async fn cache_hit_filter(
+        &self,
+        session: &Session,
+        _meta: &CacheMeta,
+        _ctx: &mut Self::CTX,
+    ) -> Result<bool> {
+        // allow test header to control force expiry
+        if session.get_header_bytes("x-force-expire") != b"" {
+            return Ok(true);
+        }
+        Ok(false)
+    }
+
     fn cache_vary_filter(
         &self,
         meta: &CacheMeta,
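The `cache_hit_filter` above is a test hook: any request carrying `x-force-expire` treats its cache hit as expired, which is what lets the second writer in `test_cache_streaming_multiple_writers` start its own streaming write instead of attaching to the first. A production filter would more likely derive the decision from request or metadata state. A hedged sketch of such a predicate follows; the helper name and the one-day cutoff are invented for illustration, and it is kept off the `ProxyHttp` trait so it stands alone (e.g. it could be fed the hit's creation time from its `CacheMeta`).

use std::time::{Duration, SystemTime};

// hypothetical policy: treat anything created more than a day ago as expired
fn should_force_expire(created: SystemTime) -> bool {
    SystemTime::now()
        .duration_since(created)
        .map(|age| age > Duration::from_secs(24 * 60 * 60))
        .unwrap_or(false) // clock skew: treat future timestamps as fresh
}

fn main() {
    let just_created = SystemTime::now();
    assert!(!should_force_expire(just_created));

    let two_days_ago = SystemTime::now() - Duration::from_secs(2 * 24 * 60 * 60);
    assert!(should_force_expire(two_days_ago));
}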