Add callback function to Backends update() to address atomicity issue when building selector

This commit is contained in:
Andrew Hauck 2024-07-20 14:50:00 -07:00 committed by Yuchen Wu
parent 7c122e7f36
commit a51874039f
2 changed files with 100 additions and 21 deletions

2
.bleep
View file

@ -1 +1 @@
78a170341a0fb030b8bcb2afe84afb268cdc5b2d
9fdf48d67b78675c989f51ec18829a81fe6976ef

View file

@ -130,8 +130,17 @@ impl Backends {
self.health_check = Some(hc.into())
}
/// Return true when the new is different from the current set of backends
fn do_update(&self, new_backends: BTreeSet<Backend>, enablement: HashMap<u64, bool>) -> bool {
/// Updates backends when the new is different from the current set,
/// the callback will be invoked when the new set of backend is different
/// from the current one so that the caller can update the selector accordingly.
fn do_update<F>(
&self,
new_backends: BTreeSet<Backend>,
enablement: HashMap<u64, bool>,
callback: F,
) where
F: Fn(Arc<BTreeSet<Backend>>),
{
if (**self.backends.load()) != new_backends {
let old_health = self.health.load();
let mut health = HashMap::with_capacity(new_backends.len());
@ -147,10 +156,14 @@ impl Backends {
health.insert(hash_key, backend_health);
}
// TODO: put backend and health under 1 ArcSwap so that this update is atomic
self.backends.store(Arc::new(new_backends));
// TODO: put this all under 1 ArcSwap so the update is atomic
// It's important the `callback()` executes first since computing selector backends might
// be expensive. For example, if a caller checks `backends` to see if any are available
// they may encounter false positives if the selector isn't ready yet.
let new_backends = Arc::new(new_backends);
callback(new_backends.clone());
self.backends.store(new_backends);
self.health.store(Arc::new(health));
true
} else {
// no backend change, just check enablement
for (hash_key, backend_enabled) in enablement.iter() {
@ -160,7 +173,6 @@ impl Backends {
backend_health.enable(*backend_enabled);
}
}
false
}
}
@ -199,12 +211,15 @@ impl Backends {
/// Call the service discovery method to update the collection of backends.
///
/// Return `true` when the new collection is different from the current set of backends.
/// This return value is useful to tell the caller when to rebuild things that are expensive to
/// update, such as consistent hashing rings.
pub async fn update(&self) -> Result<bool> {
/// The callback will be invoked when the new set of backend is different
/// from the current one so that the caller can update the selector accordingly.
pub async fn update<F>(&self, callback: F) -> Result<()>
where
F: Fn(Arc<BTreeSet<Backend>>),
{
let (new_backends, enablement) = self.discovery.discover().await?;
Ok(self.do_update(new_backends, enablement))
self.do_update(new_backends, enablement, callback);
Ok(())
}
/// Run health check on all backends if it is set.
@ -320,11 +335,9 @@ where
/// This function will be called every `update_frequency` if this [LoadBalancer] instance
/// is running as a background service.
pub async fn update(&self) -> Result<()> {
if self.backends.update().await? {
self.selector
.store(Arc::new(S::build(&self.backends.get_backend())))
}
Ok(())
self.backends
.update(|backends| self.selector.store(Arc::new(S::build(&backends))))
.await
}
/// Return the first healthy [Backend] according to the selection algorithm and the
@ -378,6 +391,8 @@ where
#[cfg(test)]
mod test {
use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
use super::*;
use async_trait::async_trait;
@ -408,10 +423,20 @@ mod test {
backends.set_health_check(check);
// true: new backend discovered
assert!(backends.update().await.unwrap());
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(updated.load(Relaxed));
// false: no new backend discovered
assert!(!backends.update().await.unwrap());
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(!updated.load(Relaxed));
backends.run_health_check(false).await;
@ -449,7 +474,14 @@ mod test {
let discovery = TestDiscovery(discovery);
let backends = Backends::new(Box::new(discovery));
assert!(backends.update().await.unwrap());
// true: new backend discovered
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(updated.load(Relaxed));
let backend = backends.get_backend();
assert!(backend.contains(&good1));
@ -476,7 +508,12 @@ mod test {
backends.set_health_check(check);
// true: new backend discovered
assert!(backends.update().await.unwrap());
let updated = AtomicBool::new(false);
backends
.update(|_| updated.store(true, Relaxed))
.await
.unwrap();
assert!(updated.load(Relaxed));
backends.run_health_check(true).await;
@ -484,4 +521,46 @@ mod test {
assert!(backends.ready(&good2));
assert!(!backends.ready(&bad));
}
mod thread_safety {
use super::*;
struct MockDiscovery {
expected: usize,
}
#[async_trait]
impl ServiceDiscovery for MockDiscovery {
async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> {
let mut d = BTreeSet::new();
let mut m = HashMap::with_capacity(self.expected);
for i in 0..self.expected {
let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap();
m.insert(i as u64, true);
d.insert(b);
}
Ok((d, m))
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_consistency() {
let expected = 3000;
let discovery = MockDiscovery { expected };
let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends(
Backends::new(Box::new(discovery)),
));
let lb2 = lb.clone();
tokio::spawn(async move {
assert!(lb2.update().await.is_ok());
});
let mut backend_count = 0;
while backend_count == 0 {
let backends = lb.backends();
backend_count = backends.backends.load_full().len();
}
assert_eq!(backend_count, expected);
assert!(lb.select_with(b"test", 1, |_, _| true).is_some());
}
}
}