Release Pingora version 0.1.0

Co-authored-by: Andrew Hauck <ahauck@cloudflare.com>
Co-authored-by: Edward Wang <ewang@cloudflare.com>
Yuchen Wu 2024-02-27 20:25:44 -08:00
parent 0bca116c10
commit 8797329225
279 changed files with 48111 additions and 18 deletions

51
.github/CONTRIBUTING.md vendored Normal file

@ -0,0 +1,51 @@
# Contributing
Welcome to Pingora! Before you make a contribution, be it a bug report, documentation improvement,
pull request (PR), etc., please read and follow these guidelines.
## Start with filing an issue
More often than not, **start by filing an issue on GitHub**. If you have a bug report or feature
request, open a GitHub issue. Non-trivial PRs will also require a GitHub issue. The issue provides
us with a space to discuss proposed changes with you and the community.
Having a discussion via GitHub issue upfront is the best way to ensure your contribution lands in
Pingora. We don't want you to spend your time making a PR, only to find that we won't accept it on
a design basis. For example, we may find that your proposed feature works better as a third-party
module built on top of or for use with Pingora and encourage you to pursue that direction instead.
**You do not need to file an issue for small fixes.** What counts as a "small" or trivial fix is a
judgment call, so here's a few examples to clarify:
- fixing a typo
- refactoring a bit of code
- most documentation or comment edits
Still, _sometimes_ we may review your PR and ask you to file an issue if we expect there are larger
design decisions to be made.
## Making a PR
After you've filed an issue, you can make your PR referencing that issue number. Once you open your
PR, it will be labelled _needs review_. A maintainer will review your PR as soon as they can. The
reviewer may ask for changes - they will mark the PR as _changes requested_ and _work in progress_
and will give you details about the requested changes. Feel free to ask lots of questions! The
maintainers are there to help you.
### Caveats
Currently, internal contributions will take priority. Today Pingora is being maintained by
Cloudflare's Content Delivery team, and internal Cloudflare proxy services are a primary user of
Pingora. We value the community's work on Pingora, but the reality is that our team has a limited
amount of resources and time. We can't promise we will review or address all PRs or issues in a
timely manner.
## Conduct
Pingora and Cloudflare OpenSource generally follows the [Contributor Covenant Code of Conduct].
Violating the CoC could result in a warning or a ban from Pingora or any and all repositories in the Cloudflare organization.
[Contributor Covenant Code of Conduct]: https://github.com/cloudflare/.github/blob/26b37ca2ba7ab3d91050ead9f2c0e30674d3b91e/CODE_OF_CONDUCT.md
## Contact
If you have any questions, please reach out to [opensource@cloudflare.com](mailto:opensource@cloudflare.com).

37
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file

@ -0,0 +1,37 @@
---
name: Bug Report
about: Report an issue to help us improve
title: ''
labels: ''
assignees: ''
---
## Describe the bug
A clear and concise description of what the bug is.
## Pingora info
Please include the following information about your environment:
**Pingora version**: release number or commit hash
**Rust version**: i.e. `cargo --version`
**Operating system version**: e.g. Ubuntu 22.04, Debian 12.4
## Steps to reproduce
Please provide step-by-step instructions to reproduce the issue. Include any relevant code
snippets.
## Expected results
What were you expecting to happen?
## Observed results
What actually happened?
## Additional context
What other information would you like to provide? e.g. screenshots, how you're working around the
issue, or other clues you think could be helpful to identify the root cause.


@ -0,0 +1,27 @@
---
name: Feature request
about: Propose a new feature
title: ''
labels: ''
assignees: ''
---
## What is the problem your feature solves, or the need it fulfills?
A clear and concise description of why this feature should be added. What is the problem? Who is
this for?
## Describe the solution you'd like
What do you propose to resolve the problem or fulfill the need above? How would you like it to
work?
## Describe alternatives you've considered
What other solutions, features, or workarounds have you considered that might also solve the issue?
What are the tradeoffs for these alternatives compared to what you're proposing?
## Additional context
This could include references to documentation or papers, prior art, screenshots, or benchmark
results.

9
.gitignore vendored

@ -1,6 +1,7 @@
+**/target
-Cargo.lock
-/target
+**/*.rs.bk
+**/Cargo.lock
+**/dhat-heap.json
-dhat-heap.json
 .vscode
-.cover
 .idea
+.cover

0
.rustfmt.toml Normal file


@ -1,4 +1,37 @@
[workspace]
resolver = "2"
members = [
    "pingora",
    "pingora-core",
    "pingora-pool",
    "pingora-error",
    "pingora-limits",
    "pingora-timeout",
    "pingora-header-serde",
    "pingora-proxy",
    "pingora-cache",
    "pingora-http",
    "pingora-lru",
    "pingora-openssl",
    "pingora-boringssl",
    "pingora-runtime",
    "pingora-ketama",
    "pingora-load-balancing",
    "pingora-memory-cache",
    "tinyufo",
]
[workspace.dependencies]
tokio = "1"
async-trait = "0.1.42"
httparse = "1"
bytes = "1.0"
http = "1.0.0"
log = "0.4"
h2 = ">=0.4.2"
once_cell = "1"
lru = "0"
ahash = ">=0.8.9"
[profile.bench]
debug = true


@ -1,5 +1,65 @@
# Pingora
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
![Pingora banner image](./docs/assets/pingora_banner.png)
A library for building fast, reliable and evolvable network services.
## What is Pingora
Pingora is a Rust framework to [build fast, reliable and programmable networked systems](https://blog.cloudflare.com/pingora-open-source).
Pingora is battle tested as it has been serving more than 40 million Internet requests per second for [more than a few years](https://blog.cloudflare.com/how-we-built-pingora-the-proxy-that-connects-cloudflare-to-the-internet).
## Feature highlights
* Async Rust: fast and reliable
* HTTP 1/2 end to end proxy
* TLS over OpenSSL or BoringSSL
* gRPC and websocket proxying
* Graceful reload
* Customizable load balancing and failover strategies
* Support for a variety of observability tools
## Reasons to use Pingora
* **Security** is your top priority: Pingora is a more memory-safe alternative to services written in C/C++.
* Your service is **performance-sensitive**: Pingora is fast and efficient.
* Your service requires extensive **customization**: the APIs the Pingora proxy framework provides are highly programmable.
# Getting started
See our [quick start guide](./docs/quick_start.md) to see how easy it is to build a load balancer.
Our [user guide](./docs/user_guide/index.md) covers more topics such as how to configure and run Pingora servers, as well as how to build custom HTTP server and proxy logic on top of Pingora's framework.
API docs are also available for all the crates.
# Notable crates in this workspace
* Pingora: the "public facing" crate to build networked systems and proxies
* Pingora-core: this crate defines the protocols, functionalities and basic traits
* Pingora-proxy: the logic and APIs to build HTTP proxies
* Pingora-error: the common error type used across Pingora crates
* Pingora-http: the HTTP header definitions and APIs
* Pingora-openssl & pingora-boringssl: SSL related extensions and APIs
* Pingora-ketama: the [Ketama](https://github.com/RJ/ketama) consistent hashing algorithm
* Pingora-limits: efficient counting algorithms
* Pingora-load-balancing: load balancing algorithm extensions for pingora-proxy
* Pingora-memory-cache: async in-memory caching with cache lock to prevent cache stampede
* Pingora-timeout: a more efficient async timer system
* TinyUfo: the caching algorithm behind pingora-memory-cache
# System requirements
## Systems
Linux is our tier 1 environment and main focus.
We will try our best to keep most code compiling for Unix environments, so that developers and users have an easier time developing with Pingora in Unix-like environments such as macOS (though some features might be missing).
Both x86_64 and aarch64 architectures will be supported.
## Rust version
Pingora keeps a rolling MSRV (minimum supported Rust version) policy of 6 months. This means we will accept PRs that upgrade the MSRV as long as the new Rust version used is at least 6 months old.
Our current MSRV is 1.72.
# Contributing
Please see our [contribution guidelines](./.github/CONTRIBUTING.md).
# License
This project is Licensed under [Apache License, Version 2.0](./LICENSE).

14
docs/README.md Normal file

@ -0,0 +1,14 @@
# Pingora User Manual
## Quick Start
In this section we show you how to build a barebones load balancer.
[Read the quick start here.](quick_start.md)
## User Guide
Covers how to configure and run Pingora servers, as well as how to build custom HTTP server and proxy logic on top of Pingora's framework.
[Read the user guide here.](user_guide/index.md)
## API Reference
TBD

Binary file not shown (new image, 235 KiB).

324
docs/quick_start.md Normal file

@ -0,0 +1,324 @@
# Quick Start: load balancer
## Introduction
This quick start shows how to build a bare-bones load balancer using pingora and pingora-proxy.
The goal of the load balancer is, for every incoming HTTP request, to select one of the two backends, https://1.1.1.1 and https://1.0.0.1, in a round-robin fashion.
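Conceptually, round-robin selection just cycles an index over the backend list. A minimal plain-Rust sketch of the idea (illustration only; the names here are made up and this is not the pingora API):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical round-robin selector: each call to `select` returns the
// next backend in the list, wrapping around at the end.
struct RoundRobinSketch {
    backends: Vec<&'static str>,
    next: AtomicUsize,
}

impl RoundRobinSketch {
    fn select(&self) -> &'static str {
        let i = self.next.fetch_add(1, Ordering::Relaxed);
        self.backends[i % self.backends.len()]
    }
}

fn main() {
    let lb = RoundRobinSketch {
        backends: vec!["1.1.1.1:443", "1.0.0.1:443"],
        next: AtomicUsize::new(0),
    };
    for _ in 0..4 {
        println!("selected backend: {}", lb.select());
    }
}
```

The real `LoadBalancer<RoundRobin>` used below adds backend parsing, weights, and health awareness on top of this basic cycling behavior.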
## Build a basic load balancer
Create a new cargo project for our load balancer. Let's call it `load_balancer`
```
cargo new load_balancer
```
### Include the Pingora Crate and Basic Dependencies
In your project's `Cargo.toml` file, add the following to your dependencies:
```
async-trait="0.1"
pingora = { version = "0.1", features = [ "lb" ] }
```
### Create a pingora server
First, let's create a pingora server. A pingora `Server` is a process which can host one or many
services. The pingora `Server` takes care of configuration and CLI argument parsing, daemonization,
signal handling, and graceful restart or shutdown.
The preferred usage is to initialize the `Server` in the `main()` function and
use `run_forever()` to spawn all the runtime threads and block the main thread until the server is
ready to exit.
```rust
use async_trait::async_trait;
use pingora::prelude::*;
use std::sync::Arc;
fn main() {
    let mut my_server = Server::new(None).unwrap();
    my_server.bootstrap();
    my_server.run_forever();
}
```
This will compile and run, but it doesn't do anything interesting.
### Create a load balancer proxy
Next let's create a load balancer. Our load balancer holds a static list of upstream IPs. The `pingora-load-balancing` crate already provides the `LoadBalancer` struct with common selection algorithms such as round robin and hashing, so let's just use it. If the use case requires more sophisticated or customized server selection logic, users can simply implement it themselves in the `upstream_peer()` callback.
```rust
pub struct LB(Arc<LoadBalancer<RoundRobin>>);
```
In order to make the server a proxy, we need to implement the `ProxyHttp` trait for it.
Any object that implements the `ProxyHttp` trait essentially defines how a request is handled in
the proxy. The only required method in the `ProxyHttp` trait is `upstream_peer()` which returns
the address where the request should be proxied to.
In the body of the `upstream_peer()`, let's use the `select()` method for the `LoadBalancer` to round-robin across the upstream IPs. In this example we use HTTPS to connect to the backends, so we also need to specify to `use_tls` and set the SNI when constructing our [`Peer`](peer.md) object.
```rust
#[async_trait]
impl ProxyHttp for LB {
    /// For this small example, we don't need context storage
    type CTX = ();

    fn new_ctx(&self) -> () {
        ()
    }

    async fn upstream_peer(&self, _session: &mut Session, _ctx: &mut ()) -> Result<Box<HttpPeer>> {
        let upstream = self.0
            .select(b"", 256) // hash doesn't matter for round robin
            .unwrap();

        println!("upstream peer is: {upstream:?}");

        // Set SNI to one.one.one.one
        let peer = Box::new(HttpPeer::new(upstream, true, "one.one.one.one".to_string()));
        Ok(peer)
    }
}
```
In order for the 1.1.1.1 backends to accept our requests, a host header must be present. We can add this header via the `upstream_request_filter()` callback, which modifies the request header after the connection to the backend is established and before the request header is sent.
```rust
impl ProxyHttp for LB {
    // ...
    async fn upstream_request_filter(
        &self,
        _session: &mut Session,
        upstream_request: &mut RequestHeader,
        _ctx: &mut Self::CTX,
    ) -> Result<()> {
        upstream_request.insert_header("Host", "one.one.one.one").unwrap();
        Ok(())
    }
}
```
### Create a pingora-proxy service
Next, let's create a proxy service that follows the instructions of the load balancer above.
A pingora `Service` listens to one or multiple (TCP or Unix domain socket) endpoints. When a new connection is established
the `Service` hands the connection over to its "application." `pingora-proxy` is such an application
which proxies the HTTP request to the given backend as configured above.
In the example below, we create a `LB` instance with two backends `1.1.1.1:443` and `1.0.0.1:443`.
We put that `LB` instance to a proxy `Service` via the `http_proxy_service()` call and then tell our
`Server` to host that proxy `Service`.
```rust
fn main() {
    let mut my_server = Server::new(None).unwrap();
    my_server.bootstrap();

    let upstreams =
        LoadBalancer::try_from_iter(["1.1.1.1:443", "1.0.0.1:443"]).unwrap();

    let mut lb = http_proxy_service(&my_server.configuration, LB(Arc::new(upstreams)));
    lb.add_tcp("0.0.0.0:6188");

    my_server.add_service(lb);
    my_server.run_forever();
}
```
### Run it
Now that we have added the load balancer to the service, we can run our new
project with
```
cargo run
```
To test it, simply send the server a few requests with the command:
```
curl 127.0.0.1:6188 -svo /dev/null
```
You can also navigate your browser to [http://localhost:6188](http://localhost:6188)
The following output shows that the load balancer is doing its job to balance across the two backends:
```
upstream peer is: Backend { addr: Inet(1.0.0.1:443), weight: 1 }
upstream peer is: Backend { addr: Inet(1.1.1.1:443), weight: 1 }
upstream peer is: Backend { addr: Inet(1.0.0.1:443), weight: 1 }
upstream peer is: Backend { addr: Inet(1.1.1.1:443), weight: 1 }
upstream peer is: Backend { addr: Inet(1.0.0.1:443), weight: 1 }
...
```
Well done! At this point you have a functional load balancer. It is a _very_
basic load balancer though, so the next section will walk you through how to
make it more robust with some built-in pingora tooling.
## Add functionality
Pingora provides several helpful features that can be enabled and configured
with just a few lines of code. These range from simple peer health checks to
the ability to seamlessly update the running binary with zero service interruption.
### Peer health checks
To make our load balancer more reliable, we would like to add some health checks
to our upstream peers. That way if there is a peer that has gone down, we can
quickly stop routing our traffic to that peer.
First let's see how our simple load balancer behaves when one of the peers is
down. To do this, we'll update the list of peers to include a peer that is
guaranteed to be broken.
```rust
fn main() {
    // ...
    let upstreams =
        LoadBalancer::try_from_iter(["1.1.1.1:443", "1.0.0.1:443", "127.0.0.1:343"]).unwrap();
    // ...
}
```
Now if we run our load balancer again with `cargo run` and test it with
```
curl 127.0.0.1:6188 -svo /dev/null
```
We can see that one in every three requests fails with `502: Bad Gateway`. This is
because our peer selection is strictly following the `RoundRobin` selection
pattern we gave it, with no consideration of whether that peer is healthy. We can
fix this by adding a basic health check service.
```rust
fn main() {
    let mut my_server = Server::new(None).unwrap();
    my_server.bootstrap();

    // Note that upstreams needs to be declared as `mut` now
    let mut upstreams =
        LoadBalancer::try_from_iter(["1.1.1.1:443", "1.0.0.1:443", "127.0.0.1:343"]).unwrap();

    let hc = TcpHealthCheck::new();
    upstreams.set_health_check(hc);
    upstreams.health_check_frequency = Some(std::time::Duration::from_secs(1));

    let background = background_service("health check", upstreams);
    let upstreams = background.task();

    // `upstreams` no longer needs to be wrapped in an Arc
    let mut lb = http_proxy_service(&my_server.configuration, LB(upstreams));
    lb.add_tcp("0.0.0.0:6188");

    my_server.add_service(background);
    my_server.add_service(lb);

    my_server.run_forever();
}
```
Now if we again run and test our load balancer, we see that all requests
succeed and the broken peer is never used. Based on the configuration we used,
if that peer were to become healthy again, it would be re-included in the round
robin within 1 second.
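What `TcpHealthCheck` does can be pictured in plain, std-only Rust: a backend counts as healthy if a TCP connection to it can be established within a timeout. This is an illustration of the idea, not pingora's implementation, and the function name is made up:

```rust
use std::net::{SocketAddr, TcpStream};
use std::time::Duration;

// Hypothetical sketch of a TCP health check: a backend is "healthy"
// if a TCP connection succeeds within the timeout.
fn tcp_health_check(addr: &str, timeout: Duration) -> bool {
    match addr.parse::<SocketAddr>() {
        Ok(sock) => TcpStream::connect_timeout(&sock, timeout).is_ok(),
        Err(_) => false, // unparseable address can never be healthy
    }
}

fn main() {
    // The broken peer from the example above should fail this check.
    let healthy = tcp_health_check("127.0.0.1:343", Duration::from_millis(200));
    println!("127.0.0.1:343 healthy: {healthy}");
}
```

The background service simply repeats a check like this at `health_check_frequency` and excludes failing backends from selection.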
### Command line options
The pingora `Server` type provides a lot of built-in functionality that we can
take advantage of with a single-line change.
```rust
fn main() {
    let mut my_server = Server::new(Some(Opt::default())).unwrap();
    // ...
}
```
With this change, the command-line arguments passed to our load balancer will be
consumed by Pingora. We can test this by running:
```
cargo run -- -h
```
We should see a help menu with the list of arguments now available to us. We
will take advantage of those in the next sections to do more with our load
balancer for free.
### Running in the background
Passing the parameter `-d` or `--daemon` will tell the program to run in the background.
```
cargo run -- -d
```
To stop this service, you can send the `SIGTERM` signal to it for a graceful shutdown, in which the service will stop accepting new requests but will try to finish all ongoing requests before exiting.
```
pkill -SIGTERM load_balancer
```
(`SIGTERM` is the default signal for `pkill`.)
### Configurations
Pingora configuration files help define how to run the service. Here is an
example config file that defines how many threads the service can have, the
location of the pid file, the error log file, and the upgrade coordination
socket (which we will explain later). Copy the contents below and put them into
a file called `conf.yaml` in your `load_balancer` project directory.
```yaml
---
version: 1
threads: 2
pid_file: /tmp/load_balancer.pid
error_log: /tmp/load_balancer_err.log
upgrade_sock: /tmp/load_balancer.sock
```
To use this conf file:
```
RUST_LOG=INFO cargo run -- -c conf.yaml -d
```
`RUST_LOG=INFO` is here so that the service actually populates the error log.
Now you can find the pid of the service.
```
cat /tmp/load_balancer.pid
```
### Gracefully upgrade the service
(Linux only)
Let's say we changed the code of the load balancer and recompiled the binary. Now we want to upgrade the service running in the background to this newer version.
If we simply stop the old service, then start the new one, some request arriving in between could be lost. Fortunately, Pingora provides a graceful way to upgrade the service.
This is done by first sending the `SIGQUIT` signal to the running server, and then starting the new server with the parameter `-u` / `--upgrade`.
```
pkill -SIGQUIT load_balancer &&\
RUST_LOG=INFO cargo run -- -c conf.yaml -d -u
```
In this process, the old running server will wait and hand over its listening sockets to the new server. Then the old server runs until all its ongoing requests finish.
From a client's perspective, the service is always running because the listening socket is never closed.
## Full examples
The full code for this example is available in this repository under
[pingora-proxy/examples/load_balancer.rs](../pingora-proxy/examples/load_balancer.rs)
Other examples that you may find helpful are also available here
[pingora-proxy/examples/](../pingora-proxy/examples/)
[pingora/examples](../pingora/examples/)

33
docs/user_guide/conf.md Normal file

@ -0,0 +1,33 @@
# Configuration
A Pingora configuration file is a list of Pingora settings in yaml format.
Example
```yaml
---
version: 1
threads: 2
pid_file: /run/pingora.pid
upgrade_sock: /tmp/pingora_upgrade.sock
user: nobody
group: webusers
```
## Settings
| Key | Meaning | Value type |
| ------------- | ------------- | ---- |
| version | the version of the conf, currently it is a constant `1` | number |
| pid_file | the path to the pid file | string |
| daemon | whether to run the server in the background | bool |
| error_log | the path to the error log output file. STDERR is used if not set | string |
| upgrade_sock | the path to the upgrade socket | string |
| threads | number of threads per service | number |
| user | the user the pingora server should run as after daemonization | string |
| group | the group the pingora server should run as after daemonization | string |
| client_bind_to_ipv4 | source IPv4 addresses to bind to when connecting to servers | list of strings |
| client_bind_to_ipv6 | source IPv6 addresses to bind to when connecting to servers | list of strings |
| ca_file | the path to the root CA file | string |
| work_stealing | enable work stealing runtime (default `true`). See the Pingora runtime (WIP) section for more info | bool |
| upstream_keepalive_pool_size | the number of total connections to keep in the connection pool | number |
## Extension
Any unknown settings will be ignored. This allows extending the conf file to add and pass user defined settings. See User defined configuration section.
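For example, a conf file can carry extra keys for the application's own use; Pingora skips over them. The `my_app_*` keys below are hypothetical:

```yaml
---
version: 1
threads: 2
# The keys below are not Pingora settings. Pingora ignores them, but
# user code that parses the same file can read them.
my_app_greeting: "hello"
my_app_max_retries: 3
```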

116
docs/user_guide/ctx.md Normal file

@ -0,0 +1,116 @@
# Sharing state across phases with `CTX`
## Using `CTX`
The custom filters users implement in different phases of the request don't interact with each other directly. In order to share information and state across the filters, users can define a `CTX` struct. Each request owns a single `CTX` object. All the filters are able to read and update members of the `CTX` object. The CTX object will be dropped at the end of the request.
### Example
In the following example, the proxy parses the request header in the `request_filter` phase and stores a boolean flag, so that later, in the `upstream_peer` phase, the flag is used to decide which server to route traffic to. (Technically, the header could be parsed in the `upstream_peer` phase; we do it in an earlier phase just for demonstration.)
```Rust
pub struct MyProxy();

pub struct MyCtx {
    beta_user: bool,
}

fn check_beta_user(req: &pingora_http::RequestHeader) -> bool {
    // some simple logic to check if user is beta
    req.headers.get("beta-flag").is_some()
}

#[async_trait]
impl ProxyHttp for MyProxy {
    type CTX = MyCtx;
    fn new_ctx(&self) -> Self::CTX {
        MyCtx { beta_user: false }
    }

    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
        ctx.beta_user = check_beta_user(session.req_header());
        Ok(false)
    }

    async fn upstream_peer(
        &self,
        _session: &mut Session,
        ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        let addr = if ctx.beta_user {
            info!("I'm a beta user");
            ("1.0.0.1", 443)
        } else {
            ("1.1.1.1", 443)
        };

        let peer = Box::new(HttpPeer::new(addr, true, "one.one.one.one".to_string()));
        Ok(peer)
    }
}
```
## Sharing state across requests
Sharing state such as a counter, cache and other info across requests is common. There is nothing special needed for sharing resources and data across requests in Pingora. `Arc`, `static` or any other mechanism can be used.
### Example
Let's modify the example above to track the number of beta visitors as well as the number of total visitors. The counters can either be defined in the `MyProxy` struct itself or defined as a global variable. Because the counters can be concurrently accessed, a `Mutex` is used here.
```Rust
// global counter
static REQ_COUNTER: Mutex<usize> = Mutex::new(0);

pub struct MyProxy {
    // counter for the service
    beta_counter: Mutex<usize>, // AtomicUsize works too
}

pub struct MyCtx {
    beta_user: bool,
}

fn check_beta_user(req: &pingora_http::RequestHeader) -> bool {
    // some simple logic to check if user is beta
    req.headers.get("beta-flag").is_some()
}

#[async_trait]
impl ProxyHttp for MyProxy {
    type CTX = MyCtx;
    fn new_ctx(&self) -> Self::CTX {
        MyCtx { beta_user: false }
    }

    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
        ctx.beta_user = check_beta_user(session.req_header());
        Ok(false)
    }

    async fn upstream_peer(
        &self,
        _session: &mut Session,
        ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        let mut req_counter = REQ_COUNTER.lock().unwrap();
        *req_counter += 1;

        let addr = if ctx.beta_user {
            let mut beta_count = self.beta_counter.lock().unwrap();
            *beta_count += 1;
            info!("I'm a beta user #{beta_count}");
            ("1.0.0.1", 443)
        } else {
            info!("I'm a user #{req_counter}");
            ("1.1.1.1", 443)
        };

        let peer = Box::new(HttpPeer::new(addr, true, "one.one.one.one".to_string()));
        Ok(peer)
    }
}
```
The complete example can be found under [`pingora-proxy/examples/ctx.rs`](../../pingora-proxy/examples/ctx.rs). You can run it using `cargo`:
```
RUST_LOG=INFO cargo run --example ctx
```


@ -0,0 +1,7 @@
# Daemonization
When a Pingora server is configured to run as a daemon, after its bootstrapping, it will move itself to the background and optionally change to run under the configured user and group. The `pid_file` option comes in handy in this case for the user to track the PID of the daemon in the background.
Daemonization also allows the server to perform privileged actions like loading secrets and then switch to an unprivileged user before accepting any requests from the network.
This process happens in the `run_forever()` call. Because daemonization involves `fork()`, certain things like threads created before this call are likely lost.
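Putting the options above together, a daemonized server could be driven by a conf like the following sketch. The paths are hypothetical; all keys are from the configuration manual:

```yaml
---
version: 1
daemon: true
user: nobody                    # drop privileges to this user after daemonization
group: webusers
pid_file: /run/pingora.pid      # find the background process's PID here
error_log: /var/log/pingora/error.log
```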


@ -0,0 +1,13 @@
# Error logging
Pingora libraries are built to expect issues like disconnects, timeouts and invalid inputs from the network. A common way to record these issues is to output them to an error log (STDERR or log files).
## Log level guidelines
Pingora adopts the idea behind [log](https://docs.rs/log/latest/log/). There are five log levels:
* `error`: This level should be used when the error stops the request from being handled correctly. For example when the server we try to connect to is offline.
* `warning`: This level should be used when an error occurs but the system recovers from it. For example when the primary DNS timed out but the system is able to query the secondary DNS.
* `info`: Pingora logs when the server is starting up or shutting down.
* `debug`: Internal details. This log level is not compiled in `release` builds.
* `trace`: Fine-grained internal details. This log level is not compiled in `release` builds.
The pingora-proxy crate has a well-defined interface to log errors, so that users don't have to manually log common proxy errors. See its guide for more details.

53
docs/user_guide/errors.md Normal file

@ -0,0 +1,53 @@
# How to return errors
For easy error handling, the `pingora-error` crate exports a custom `Result` type used throughout other Pingora crates.
The `Error` struct used in this `Result`'s error variant is a wrapper around arbitrary error types. It allows the user to tag the source of the underlying error and attach other custom context info.
Users will often need to return errors by propagating an existing error or creating a wholly new one. `pingora-error` makes this easy with its error building functions.
## Examples
For example, one could return an error when an expected header is not present:
```rust
fn validate_req_header(req: &RequestHeader) -> Result<()> {
    // validate that the `host` header exists
    req.headers
        .get(http::header::HOST)
        .ok_or_else(|| Error::explain(InvalidHTTPHeader, "No host header detected"))?;
    Ok(())
}

impl MyServer {
    pub async fn handle_request_filter(
        &self,
        http_session: &mut Session,
        ctx: &mut CTX,
    ) -> Result<bool> {
        validate_req_header(http_session.req_header()).or_err(HTTPStatus(400), "Missing required headers")?;
        Ok(true)
    }
}
```
`validate_req_header` returns an `Error` if the `host` header is not found, using `Error::explain` to create a new `Error` along with an associated type (`InvalidHTTPHeader`) and helpful context that may be logged in an error log.
This error will eventually propagate to the request filter, where it is returned as a new `HTTPStatus` error using `or_err`. (As part of the default pingora-proxy `fail_to_proxy()` phase, not only will this error be logged, but it will result in sending a `400 Bad Request` response downstream.)
Note that the original causing error will be visible in the error logs as well. `or_err` wraps the original causing error in a new one with additional context, but `Error`'s `Display` implementation also prints the chain of causing errors.
## Guidelines
An error has a _type_ (e.g. `ConnectionClosed`), a _source_ (e.g. `Upstream`, `Downstream`, `Internal`), and optionally, a _cause_ (another wrapped error) and a _context_ (arbitrary user-provided string details).
A minimal error can be created using functions like `new_in` / `new_up` / `new_down`, each of which specifies a source and asks the user to provide a type.
Generally speaking:
* To create a new error, without a direct cause but with more context, use `Error::explain`. You can also use `explain_err` on a `Result` to replace the potential error inside it with a new one.
* To wrap a causing error in a new one with more context, use `Error::because`. You can also use `or_err` on a `Result` to replace the potential error inside it by wrapping the original one.
## Retry
Errors can be "retry-able." If the error is retry-able, pingora-proxy will be allowed to retry the upstream request. Some errors are only retry-able on [reused connections](pooling.md), e.g. to handle situations where the remote end has dropped a connection we attempted to reuse.
By default a newly created `Error` either takes on its direct causing error's retry status, or, if left unspecified, is considered not retry-able.


@ -0,0 +1,67 @@
# Handling failures and failover
Pingora-proxy allows users to define how to handle failures throughout the life of a proxied request.
When a failure happens before the response header is sent downstream, users have a few options:
1. Send an error page downstream and then give up.
2. Retry the same upstream again.
3. Try another upstream if applicable.
Otherwise, once the response header is already sent downstream, there is nothing the proxy can do other than logging an error and then giving up on the request.
## Retry / Failover
In order to implement retry or failover, `fail_to_connect()` / `error_while_proxy()` needs to mark the error as "retry-able." For failover, `fail_to_connect() / error_while_proxy()` also needs to update the `CTX` to tell `upstream_peer()` not to use the same `Peer` again.
### Safety
In general, idempotent HTTP requests, e.g., `GET`, are safe to retry. Other requests, e.g., `POST`, are not safe to retry if the requests have already been sent. When `fail_to_connect()` is called, pingora-proxy guarantees that nothing was sent upstream. Users are not recommended to retry a non-idempotent request after `error_while_proxy()` unless they know the upstream server well enough to know whether it is safe.
### Example
In the following example we set a `tries` variable on the `CTX` to track how many connection attempts we've made. When setting our peer in `upstream_peer()`, we check whether `tries` is less than one and connect to 192.0.2.1. On connect failure, we increment `tries` in `fail_to_connect()` and call `e.set_retry(true)`, which tells Pingora this is a retryable error. On retry, we enter `upstream_peer()` again and this time connect to 1.1.1.1. If we're unable to connect to 1.1.1.1, we return a 502, since we only set `e.set_retry(true)` in `fail_to_connect()` when `tries` is zero.
```Rust
pub struct MyProxy();

pub struct MyCtx {
    tries: usize,
}

#[async_trait]
impl ProxyHttp for MyProxy {
    type CTX = MyCtx;
    fn new_ctx(&self) -> Self::CTX {
        MyCtx { tries: 0 }
    }

    fn fail_to_connect(
        &self,
        _session: &mut Session,
        _peer: &HttpPeer,
        ctx: &mut Self::CTX,
        mut e: Box<Error>,
    ) -> Box<Error> {
        if ctx.tries > 0 {
            return e;
        }
        ctx.tries += 1;
        e.set_retry(true);
        e
    }

    async fn upstream_peer(
        &self,
        _session: &mut Session,
        ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        let addr = if ctx.tries < 1 {
            ("192.0.2.1", 443)
        } else {
            ("1.1.1.1", 443)
        };

        let mut peer = Box::new(HttpPeer::new(addr, true, "one.one.one.one".to_string()));
        peer.options.connection_timeout = Some(Duration::from_millis(100));
        Ok(peer)
    }
}
```

@@ -0,0 +1,19 @@
# Graceful restart and shutdown
Graceful restart, upgrade, and shutdown mechanisms are very commonly used to avoid errors or downtime when releasing new versions of pingora servers.
Pingora's graceful upgrade mechanism guarantees the following:
* A request is guaranteed to be handled either by the old server instance or the new one. No request will see connection refused when trying to connect to the server endpoints.
* A request that can finish within the grace period is guaranteed not to be terminated.
## How to graceful upgrade
### Step 0
Configure the upgrade socket. The old and new server need to agree on the same path to this socket. See configuration manual for details.
### Step 1
Start the new instance with the `--upgrade` CLI option. The new instance will not try to listen on the service endpoints right away. Instead, it will try to acquire the listening sockets from the old instance.
### Step 2
Send the SIGQUIT signal to the old instance. The old instance will start to transfer its listening sockets to the new instance.
Once step 2 is successful, the new instance will start to handle new incoming connections right away. Meanwhile, the old instance will enter its graceful shutdown mode. It waits a short period of time (to give the new instance time to initialize and prepare to handle traffic), after which it will not accept any new connections.

docs/user_guide/index.md
@@ -0,0 +1,31 @@
# User Guide
In this guide, we will cover the most used features, operations and settings of Pingora.
## Running Pingora servers
* [Start and stop](start_stop.md)
* [Graceful restart and graceful shutdown](graceful.md)
* [Configuration](conf.md)
* [Daemonization](daemon.md)
* [Systemd integration](systemd.md)
* [Handling panics](panic.md)
* [Error logging](error_log.md)
* [Prometheus](prom.md)
## Building HTTP proxies
* [Life of a request: `pingora-proxy` phases and filters](phase.md)
* [`Peer`: how to connect to upstream](peer.md)
* [Sharing state across phases with `CTX`](ctx.md)
* [How to return errors](errors.md)
* [Examples: take control of the request](modify_filter.md)
* [Connection pooling and reuse](pooling.md)
* [Handling failures and failover](failover.md)
## Advanced topics (WIP)
* [Pingora internals](internals.md)
* Using BoringSSL
* User defined configuration
* Pingora async runtime and threading model
* Background Service
* Blocking code in async context
* Tracing

@@ -0,0 +1,256 @@
# Pingora Internals
(Special thanks to [James Munns](https://github.com/jamesmunns) for writing this section)
## Starting the `Server`
The pingora system starts by spawning a *server*. The server is responsible for starting *services*, and listening for termination events.
```
┌───────────┐
┌─────────>│ Service │
│ └───────────┘
┌────────┐ │ ┌───────────┐
│ Server │──Spawns──┼─────────>│ Service │
└────────┘ │ └───────────┘
│ ┌───────────┐
└─────────>│ Service │
└───────────┘
```
After spawning the *services*, the server continues to listen to a termination event, which it will propagate to the created services.
## Services
*Services* are entities that handle listening to given sockets, and perform the core functionality. A *service* is tied to a particular protocol and set of options.
> NOTE: there are also "background" services, which just do *stuff*, and aren't necessarily listening to a socket. For now we're just talking about listener services.
Each service has its own thread pool / tokio runtime, with a number of threads based on the configured value. Worker threads are not shared across services. Service runtime thread pools may be work-stealing (the tokio default) or non-work-stealing (N isolated single-threaded runtimes).
```
┌─────────────────────────┐
│ ┌─────────────────────┐ │
│ │┌─────────┬─────────┐│ │
│ ││ Conn │ Conn ││ │
│ │├─────────┼─────────┤│ │
│ ││Endpoint │Endpoint ││ │
│ │├─────────┴─────────┤│ │
│ ││ Listeners ││ │
│ │├─────────┬─────────┤│ │
│ ││ Worker │ Worker ││ │
│ ││ Thread │ Thread ││ │
│ │├─────────┴─────────┤│ │
│ ││ Tokio Executor ││ │
│ │└───────────────────┘│ │
│ └─────────────────────┘ │
│ ┌───────┐ │
└─┤Service├───────────────┘
└───────┘
```
## Service Listeners
At startup, each Service is assigned a set of downstream endpoints that it listens to. A single service may listen to more than one endpoint. The Server also passes along any relevant configuration, including TLS settings if applicable.
These endpoints are converted into listening sockets, called `TransportStack`s. Each `TransportStack` is assigned to an async task within that service's executor.
```
┌───────────────────┐
│┌─────────────────┐│ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─
┌─────────┐ ││ TransportStack ││ ┌────────────────────┐│
┌┤Listeners├────────┐ ││ ││ │ │ ││ │
│└─────────┘ │ ││ (Listener, TLS │├──────spawn(run_endpoint())────>│ Service<ServerApp> ││
│┌─────────────────┐│ ││ Acceptor, ││ │ │ ││ │
││ Endpoint ││ ││ UpgradeFDs) ││ └────────────────────┘│
││ addr/ports ││ │├─────────────────┤│ │ │ │
││ + TLS Settings ││ ││ TransportStack ││ ┌────────────────────┐│
│├─────────────────┤│ ││ ││ │ │ ││ │
││ Endpoint ││──build()─> ││ (Listener, TLS │├──────spawn(run_endpoint())────>│ Service<ServerApp> ││
││ addr/ports ││ ││ Acceptor, ││ │ │ ││ │
││ + TLS Settings ││ ││ UpgradeFDs) ││ └────────────────────┘│
│├─────────────────┤│ │├─────────────────┤│ │ │ │
││ Endpoint ││ ││ TransportStack ││ ┌────────────────────┐│
││ addr/ports ││ ││ ││ │ │ ││ │
││ + TLS Settings ││ ││ (Listener, TLS │├──────spawn(run_endpoint())────>│ Service<ServerApp> ││
│└─────────────────┘│ ││ Acceptor, ││ │ │ ││ │
└───────────────────┘ ││ UpgradeFDs) ││ └────────────────────┘│
│└─────────────────┘│ │ ┌───────────────┐ │ │ ┌──────────────┐
└───────────────────┘ ─│start_service()│─ ─ ─ ─│ Worker Tasks ├ ─ ─ ┘
└───────────────┘ └──────────────┘
```
## Downstream connection lifecycle
Each service processes incoming connections by spawning a task-per-connection. These connections are held open
as long as there are new events to be handled.
```
┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐
│ ┌───────────────┐ ┌────────────────┐ ┌─────────────────┐ ┌─────────────┐ │
┌────────────────────┐ │ UninitStream │ │ Service │ │ App │ │ Task Ends │
│ │ │ │ ::handshake() │──>│::handle_event()│──>│ ::process_new() │──┬>│ │ │
│ Service<ServerApp> │──spawn()──> └───────────────┘ └────────────────┘ └─────────────────┘ │ └─────────────┘
│ │ │ ▲ │ │
└────────────────────┘ │ while
│ └─────────reuse │
┌───────────────────────────┐
└ ─│ Task on Service Runtime │─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘
└───────────────────────────┘
```
## What is a proxy then?
Interestingly, the `pingora` `Server` itself has no particular notion of a Proxy.
Instead, it only thinks in terms of `Service`s, which are expected to contain a particular implementor of the `ServiceApp` trait.
For example, this is how an `HttpProxy` struct, from the `pingora-proxy` crate, "becomes" a `Service` spawned by the `Server`:
```
┌─────────────┐
│ HttpProxy │
│ (struct) │
└─────────────┘
implements ┌─────────────┐
│ │HttpServerApp│
└───────>│ (trait) │
└─────────────┘
implements ┌─────────────┐
│ │ ServerApp │
└───────>│ (trait) │
└─────────────┘
contained ┌─────────────────────┐
within │ │
└───────>│ Service<ServiceApp>
│ │
└─────────────────────┘
```
Different functionality and helpers are provided at different layers in this representation.
```
┌─────────────┐ ┌──────────────────────────────────────┐
│ HttpProxy │ │Handles high level Proxying workflow, │
│ (struct) │─ ─ ─ ─ │ customizable via ProxyHttp trait │
└──────┬──────┘ └──────────────────────────────────────┘
┌──────▼──────┐ ┌──────────────────────────────────────┐
│HttpServerApp│ │ Handles selection of H1 vs H2 stream │
│ (trait) │─ ─ ─ ─ │ handling, incl H2 handshake │
└──────┬──────┘ └──────────────────────────────────────┘
┌──────▼──────┐ ┌──────────────────────────────────────┐
│ ServerApp │ │ Handles dispatching of App instances │
│ (trait) │─ ─ ─ ─ │ as individual tasks, per Session │
└──────┬──────┘ └──────────────────────────────────────┘
┌──────▼──────┐ ┌──────────────────────────────────────┐
│ Service<A> │ │ Handles dispatching of App instances │
│ (struct) │─ ─ ─ ─ │ as individual tasks, per Listener │
└─────────────┘ └──────────────────────────────────────┘
```
The `HttpProxy` struct handles the high level workflow of proxying an HTTP connection.
It uses the `ProxyHttp` (note the flipped word order!) **trait** to allow customization
at each of the following steps (note: taken from [the phase chart](./phase_chart.md) doc):
```mermaid
graph TD;
start("new request")-->request_filter;
request_filter-->upstream_peer;
upstream_peer-->Connect{{IO: connect to upstream}};
Connect--connection success-->connected_to_upstream;
Connect--connection failure-->fail_to_connect;
connected_to_upstream-->upstream_request_filter;
upstream_request_filter --> SendReq{{IO: send request to upstream}};
SendReq-->RecvResp{{IO: read response from upstream}};
RecvResp-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done");
fail_to_connect --can retry-->upstream_peer;
fail_to_connect --can't retry-->fail_to_proxy--send error response-->logging;
RecvResp--failure-->IOFailure;
SendReq--failure-->IOFailure;
error_while_proxy--can retry-->upstream_peer;
error_while_proxy--can't retry-->fail_to_proxy;
request_filter --send response-->logging
Error>any response filter error]-->error_while_proxy
IOFailure>IO error]-->error_while_proxy
```
## Zooming out
Before we zoom in, it's probably good to zoom out and remind ourselves how
a proxy generally works:
```
┌────────────┐ ┌─────────────┐ ┌────────────┐
│ Downstream │ │ Proxy │ │ Upstream │
│ Client │─────────>│ │────────>│ Server │
└────────────┘ └─────────────┘ └────────────┘
```
The proxy will be taking connections from the **Downstream** client, and (if
everything goes right), establishing a connection with the appropriate
**Upstream** server. This selected upstream server is referred to as
the **Peer**.
Once the connection is established, the Downstream and Upstream can communicate
bidirectionally.
So far, the discussion of Server, Services, and Listeners has focused on the LEFT
half of this diagram, handling incoming Downstream connections and getting them TO
the proxy component.
Next, we'll look at the RIGHT half of this diagram, connecting to Upstreams.
## Managing the Upstream
Connections to Upstream Peers are made through `Connector`s. This is not a specific type or trait, but more
of a "style".
Connectors are responsible for a few things:
* Establishing a connection with a Peer
* Maintaining a connection pool with the Peer, allowing for connection reuse across:
* Multiple requests from a single downstream client
* Multiple requests from different downstream clients
* Measuring health of connections, for connections like H2, which perform regular pings
* Handling protocols with multiple poolable layers, like H2
* Caching, if relevant to the protocol and enabled
* Compression, if relevant to the protocol and enabled
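To make the pooling responsibility concrete, here is a toy model of a connector keeping idle connections keyed by the peer they connect to. The types and names (`PeerKey`, `Pool`) are illustrative only, not Pingora's actual internals, and a real key covers more attributes (scheme, client cert, verification settings, proxy settings):

```rust
use std::collections::HashMap;

// Illustrative reuse key; a real connector keys on more attributes.
#[derive(Clone, PartialEq, Eq, Hash)]
struct PeerKey {
    addr: String,
    sni: String,
}

// Toy connection pool: `u32` stands in for a pooled connection handle.
#[derive(Default)]
struct Pool {
    idle: HashMap<PeerKey, Vec<u32>>,
}

impl Pool {
    /// Reuse an idle connection to this exact peer, if one exists.
    fn get(&mut self, key: &PeerKey) -> Option<u32> {
        self.idle.get_mut(key).and_then(|conns| conns.pop())
    }

    /// Return a finished connection to the pool for later reuse.
    fn put(&mut self, key: PeerKey, conn: u32) {
        self.idle.entry(key).or_default().push(conn);
    }
}

fn main() {
    let mut pool = Pool::default();
    let key = PeerKey {
        addr: "1.1.1.1:443".into(),
        sni: "one.one.one.one".into(),
    };
    assert!(pool.get(&key).is_none()); // nothing pooled yet: must connect
    pool.put(key.clone(), 7);
    assert_eq!(pool.get(&key), Some(7)); // a later request reuses it
}
```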
Now in context, we can see how each end of the Proxy is handled:
```
┌────────────┐ ┌─────────────┐ ┌────────────┐
│ Downstream │ ┌ ─│─ Proxy ┌ ┼ ─ │ Upstream │
│ Client │─────────>│ │ │──┼─────>│ Server │
└────────────┘ │ └───────────┼─┘ └────────────┘
─ ─ ┘ ─ ─ ┘
▲ ▲
┌──┘ └──┐
│ │
┌ ─ ─ ─ ─ ┐ ┌ ─ ─ ─ ─ ─
Listeners Connectors│
└ ─ ─ ─ ─ ┘ └ ─ ─ ─ ─ ─
```
## What about multiple peers?
`Connectors` only handle the connection to a single peer, so selecting one of potentially multiple Peers
is actually handled one level up, in the `upstream_peer()` method of the `ProxyHttp` trait.

@@ -0,0 +1,133 @@
# Examples: taking control of the request
In this section we will go through how to route, modify or reject requests.
## Routing
Any information from the request can be used to make routing decisions. Pingora doesn't impose any constraints on how users implement their own routing logic.
In the following example, the proxy sends traffic to 1.0.0.1 only when the request path starts with `/family/`. All other requests are routed to 1.1.1.1.
```Rust
pub struct MyGateway;

#[async_trait]
impl ProxyHttp for MyGateway {
    type CTX = ();
    fn new_ctx(&self) -> Self::CTX {}

    async fn upstream_peer(
        &self,
        session: &mut Session,
        _ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        let addr = if session.req_header().uri.path().starts_with("/family/") {
            ("1.0.0.1", 443)
        } else {
            ("1.1.1.1", 443)
        };

        info!("connecting to {addr:?}");

        let peer = Box::new(HttpPeer::new(addr, true, "one.one.one.one".to_string()));
        Ok(peer)
    }
}
```
## Modifying headers
Both request and response headers can be added, removed or modified in their corresponding phases. In the following example, we add logic to the `response_filter` phase to update the `Server` header and remove the `alt-svc` header.
```Rust
#[async_trait]
impl ProxyHttp for MyGateway {
    ...
    async fn response_filter(
        &self,
        _session: &mut Session,
        upstream_response: &mut ResponseHeader,
        _ctx: &mut Self::CTX,
    ) -> Result<()>
    where
        Self::CTX: Send + Sync,
    {
        // replace existing header if any
        upstream_response
            .insert_header("Server", "MyGateway")
            .unwrap();
        // because we don't support h3
        upstream_response.remove_header("alt-svc");

        Ok(())
    }
}
```
## Return Error pages
Sometimes, under certain conditions such as authentication failure, you might want the proxy to return an error page instead of proxying the traffic.
```Rust
fn check_login(req: &pingora_http::RequestHeader) -> bool {
    // implement your login check logic here
    req.headers.get("Authorization").map(|v| v.as_bytes()) == Some(b"password")
}

#[async_trait]
impl ProxyHttp for MyGateway {
    ...
    async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result<bool> {
        if session.req_header().uri.path().starts_with("/login")
            && !check_login(session.req_header())
        {
            let _ = session.respond_error(403).await;
            // true: tell the proxy that the response is already written
            return Ok(true);
        }
        Ok(false)
    }
}
```
## Logging
Logging logic can be added to the `logging` phase of Pingora. The logging phase runs on every request right before Pingora proxy finishes processing it. This phase runs for both successful and failed requests.
In the example below, we add a Prometheus metric and access logging to the proxy. In order for the metrics to be scraped, we also start a Prometheus metric server on a different port.
```Rust
pub struct MyGateway {
    req_metric: prometheus::IntCounter,
}

#[async_trait]
impl ProxyHttp for MyGateway {
    ...
    async fn logging(
        &self,
        session: &mut Session,
        _e: Option<&pingora::Error>,
        ctx: &mut Self::CTX,
    ) {
        let response_code = session
            .response_written()
            .map_or(0, |resp| resp.status.as_u16());
        // access log
        info!(
            "{} response code: {response_code}",
            self.request_summary(session, ctx)
        );

        self.req_metric.inc();
    }
}

fn main() {
    ...
    let mut prometheus_service_http =
        pingora::services::listening::Service::prometheus_http_service();
    prometheus_service_http.add_tcp("127.0.0.1:6192");
    my_server.add_service(prometheus_service_http);
    my_server.run_forever();
}
```

docs/user_guide/panic.md
@@ -0,0 +1,10 @@
# Handling panics
A panic that happens while handling a particular request does not affect other ongoing requests or the server's ability to handle future requests. Sockets acquired by the panicking request are dropped (closed). The panics are captured by the tokio runtime and then ignored.
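The isolation described above can be illustrated with plain threads, as an analogy rather than Pingora code: a panic in one spawned task is contained and does not take down sibling work.

```rust
use std::thread;

fn main() {
    // A "request handler" that panics...
    let bad = thread::spawn(|| -> u32 { panic!("handler panicked") });
    // ...does not affect a concurrently running sibling.
    let good = thread::spawn(|| 2 + 2);

    assert!(bad.join().is_err());        // the panic was contained
    assert_eq!(good.join().unwrap(), 4); // other work completes normally
}
```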
In order to monitor the panics, Pingora server has built-in Sentry integration.
```rust
my_server.sentry = Some("SENTRY_DSN");
```
Even though a panic is not fatal in Pingora, it is still not the preferred way to handle failures like network timeouts. Panics should be reserved for unexpected logic errors.

docs/user_guide/peer.md
@@ -0,0 +1,35 @@
# `Peer`: how to connect to upstream
In the `upstream_peer()` phase the user should return a `Peer` object which defines how to connect to a certain upstream.
## `Peer`
An `HttpPeer` defines which upstream to connect to.
| attribute | meaning |
| ------------- |-------------|
|address: `SocketAddr`| The IP:Port to connect to |
|scheme: `Scheme`| Http or Https |
|sni: `String`| The SNI to use, Https only |
|proxy: `Option<Proxy>`| The setting to proxy the request through a [CONNECT proxy](https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/CONNECT) |
|client_cert_key: `Option<Arc<CertKey>>`| The client certificate to use in mTLS connections to upstream |
|options: `PeerOptions`| See below |
## `PeerOptions`
A `PeerOptions` defines how to connect to the upstream.
| attribute | meaning |
| ------------- |-------------|
|bind_to: `Option<InetSocketAddr>`| Which local address to bind to as the client IP |
|connection_timeout: `Option<Duration>`| How long to wait before giving up *establishing* a TCP connection |
|total_connection_timeout: `Option<Duration>`| How long to wait before giving up *establishing* a connection including TLS handshake time |
|read_timeout: `Option<Duration>`| How long to wait before each individual `read()` from upstream. The timer is reset after each `read()` |
|idle_timeout: `Option<Duration>`| How long to wait before closing an idle connection waiting for connection reuse |
|write_timeout: `Option<Duration>`| How long to wait before a `write()` to upstream finishes |
|verify_cert: `bool`| Whether to validate the upstream server's cert |
|verify_hostname: `bool`| Whether to check if upstream server cert's CN matches the SNI |
|alternative_cn: `Option<String>`| Accept the cert if the CN matches this name |
|alpn: `ALPN`| Which HTTP protocol to advertise during ALPN, http1.1 and/or http2 |
|ca: `Option<Arc<Box<[X509]>>>`| Which Root CA to use to validate the server's cert |
|tcp_keepalive: `Option<TcpKeepalive>`| TCP keepalive settings to upstream |
## Examples
TBD

docs/user_guide/phase.md
@@ -0,0 +1,126 @@
# Life of a request: pingora-proxy phases and filters
## Intro
The pingora-proxy HTTP proxy framework supports highly programmable proxy behaviors. This is done by allowing users to inject custom logic into different phases (stages) in the life of a request.
## Life of a proxied HTTP request
1. The life of a proxied HTTP request starts when the proxy reads the request header from the **downstream** (i.e., the client).
2. Then, the proxy connects to the **upstream** (i.e., the remote server). This step is skipped if there is a previously established [connection to reuse](pooling.md).
3. The proxy then sends the request header to the upstream.
4. Once the request header is sent, the proxy enters a duplex mode, which simultaneously proxies:
a. upstream response (both header and body) to the downstream, and
b. downstream request body to upstream (if any).
5. Once the entire request/response finishes, the life of the request is ended. All resources are released. The downstream connections and the upstream connections are recycled to be reused if applicable.
## Pingora-proxy phases and filters
Pingora-proxy allows users to insert arbitrary logic into the life of a request.
```mermaid
graph TD;
start("new request")-->request_filter;
request_filter-->upstream_peer;
upstream_peer-->Connect{{IO: connect to upstream}};
Connect--connection success-->connected_to_upstream;
Connect--connection failure-->fail_to_connect;
connected_to_upstream-->upstream_request_filter;
upstream_request_filter --> SendReq{{IO: send request to upstream}};
SendReq-->RecvResp{{IO: read response from upstream}};
RecvResp-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done");
fail_to_connect --can retry-->upstream_peer;
fail_to_connect --can't retry-->fail_to_proxy--send error response-->logging;
RecvResp--failure-->IOFailure;
SendReq--failure-->IOFailure;
error_while_proxy--can retry-->upstream_peer;
error_while_proxy--can't retry-->fail_to_proxy;
request_filter --send response-->logging
Error>any response filter error]-->error_while_proxy
IOFailure>IO error]-->error_while_proxy
```
### General filter usage guidelines
* Most filters return a [`pingora_error::Result<_>`](errors.md). When the returned value is `Result::Err`, `fail_to_proxy()` will be called and the request will be terminated.
* Most filters are async functions, which allows other async operations such as IO to be performed within the filters.
* A per-request `CTX` object can be defined to share states across the filters of the same request. All filters have mutable access to this object.
* Most filters are optional.
* The reason both `upstream_response_*_filter()` and `response_*_filter()` exist is for HTTP caching integration reasons (still WIP).
### `request_filter()`
This is the first phase of every request.
This phase is usually for validating request inputs, rate limiting, and initializing context.
### `proxy_upstream_filter()`
This phase determines if we should continue to the upstream to serve a response. If we short-circuit, a 502 is returned by default, but a different response can be implemented.
This phase returns a boolean determining if we should continue to the upstream or error.
### `upstream_peer()`
This phase decides which upstream to connect to (e.g. with DNS lookup and hashing/round-robin), and how to connect to it.
This phase returns a `Peer` that defines the upstream to connect to. Implementing this phase is **required**.
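As a self-contained sketch of the selection logic this phase might host, here is a minimal round-robin picker. This is illustrative only; a real implementation might instead do DNS resolution, consistent hashing, or health-aware selection before building the `Peer`:

```rust
/// Minimal round-robin upstream selector; illustrative only.
struct RoundRobin {
    next: usize,
}

impl RoundRobin {
    fn new() -> Self {
        RoundRobin { next: 0 }
    }

    /// Pick the next upstream address in rotation.
    fn pick<'a>(&mut self, upstreams: &'a [&'a str]) -> &'a str {
        let picked = upstreams[self.next % upstreams.len()];
        self.next += 1;
        picked
    }
}

fn main() {
    let upstreams = ["10.0.0.1:443", "10.0.0.2:443"];
    let mut rr = RoundRobin::new();
    assert_eq!(rr.pick(&upstreams), "10.0.0.1:443");
    assert_eq!(rr.pick(&upstreams), "10.0.0.2:443");
    assert_eq!(rr.pick(&upstreams), "10.0.0.1:443"); // wraps around
}
```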
### `connected_to_upstream()`
This phase is executed when upstream is successfully connected.
Usually this phase is for logging purposes. Connection info such as RTT and upstream TLS ciphers are reported in this phase.
### `fail_to_connect()`
The counterpart of `connected_to_upstream()`. This phase is called if an error is encountered when connecting to upstream.
In this phase users can report the error in Sentry/Prometheus/error log. Users can also decide if the error is retry-able.
If the error is retry-able, `upstream_peer()` will be called again, in which case the user can decide whether to retry the same upstream or failover to a secondary one.
If the error is not retry-able, the request will end.
### `upstream_request_filter()`
This phase is to modify requests before sending to upstream.
### `upstream_response_filter()/upstream_response_body_filter()`
This phase is triggered after an upstream response header/body is received.
This phase is to modify response headers (or body) before sending to downstream. Note that this phase is called _prior_ to HTTP caching and therefore any changes made here will affect the response stored in the HTTP cache.
### `response_filter()/response_body_filter()/response_trailer_filter()`
This phase is triggered after a response header/body/trailer is ready to send to downstream.
This phase is to modify them before sending to downstream.
### `error_while_proxy()`
This phase is triggered when proxy errors happen after the connection to the upstream has been established.
This phase may decide to retry a request if the connection was re-used and the HTTP method is idempotent.
### `fail_to_proxy()`
This phase is called whenever an error is encountered during any of the phases above.
This phase is usually for error logging and error reporting to downstream.
### `logging()`
This is the last phase that runs after the request is finished (or errors) and before any of its resources are released. Every request will end up in this final phase.
This phase is usually for logging and post request cleanup.
### `request_summary()`
This is not a phase, but a commonly used callback.
Every error that reaches `fail_to_proxy()` will be automatically logged in the error log. `request_summary()` will be called to dump the info regarding the request when logging the error.
This callback returns a string which allows users to customize what info to dump in the error log to help track and debug the failures.
### `suppress_error_log()`
This is also not a phase, but another callback.
`fail_to_proxy()` errors are automatically logged in the error log, but users may not be interested in every error. For example, downstream errors are logged if the client disconnects early, but these errors can become noisy if users are mainly interested in observing upstream issues. This callback can inspect the error and return true or false. If true, the error will not be written to the log.
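The decision this callback makes can be sketched as a small classifier. The enum and standalone function below are illustrative stand-ins, not Pingora's actual error type or callback signature:

```rust
/// Where an error originated; illustrative, not Pingora's error type.
#[derive(PartialEq)]
enum ErrorSource {
    Upstream,
    Downstream,
    Internal,
}

/// Suppress noisy downstream early-disconnect errors; log everything else.
fn should_suppress(source: ErrorSource, client_disconnected_early: bool) -> bool {
    source == ErrorSource::Downstream && client_disconnected_early
}

fn main() {
    assert!(should_suppress(ErrorSource::Downstream, true)); // hide the noise
    assert!(!should_suppress(ErrorSource::Upstream, false)); // keep upstream errors
    assert!(!should_suppress(ErrorSource::Internal, true));  // keep internal errors
}
```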
### Cache filters
To be documented

@@ -0,0 +1,30 @@
Pingora proxy phases without caching
```mermaid
graph TD;
start("new request")-->request_filter;
request_filter-->upstream_peer;
upstream_peer-->Connect{{IO: connect to upstream}};
Connect--connection success-->connected_to_upstream;
Connect--connection failure-->fail_to_connect;
connected_to_upstream-->upstream_request_filter;
upstream_request_filter --> SendReq{{IO: send request to upstream}};
SendReq-->RecvResp{{IO: read response from upstream}};
RecvResp-->upstream_response_filter-->response_filter-->upstream_response_body_filter-->response_body_filter-->logging-->endreq("request done");
fail_to_connect --can retry-->upstream_peer;
fail_to_connect --can't retry-->fail_to_proxy--send error response-->logging;
RecvResp--failure-->IOFailure;
SendReq--failure-->IOFailure;
error_while_proxy--can retry-->upstream_peer;
error_while_proxy--can't retry-->fail_to_proxy;
request_filter --send response-->logging
Error>any response filter error]-->error_while_proxy
IOFailure>IO error]-->error_while_proxy
```

View file

@ -0,0 +1,22 @@
# Connection pooling and reuse
When the request to a `Peer` (upstream server) is finished, the connection to that peer is kept alive and added to a connection pool to be _reused_ by subsequent requests. This happens automatically without any special configuration.
Requests that reuse previously established connections avoid the latency and compute cost of setting up a new connection, improving the Pingora server's overall performance and scalability.
## Same `Peer`
Only connections to the exact same `Peer` can be reused by a request. For correctness and security reasons, two `Peer`s are the same if and only if all of the following attributes are the same:
* IP:port
* scheme
* SNI
* client cert
* verify cert
* verify hostname
* alternative_cn
* proxy settings
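The all-attributes rule above can be sketched with a derived equality key. The field names here are illustrative (the real `Peer` carries more detail), but the point holds: changing any single field yields a different key, and therefore no connection reuse.

```rust
/// Illustrative reuse key: connections are shared only when every field matches.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ReuseKey {
    addr: (String, u16),
    scheme: String,
    sni: String,
    verify_cert: bool,
    verify_hostname: bool,
}

fn main() {
    let a = ReuseKey {
        addr: ("1.1.1.1".into(), 443),
        scheme: "https".into(),
        sni: "one.one.one.one".into(),
        verify_cert: true,
        verify_hostname: true,
    };
    // Same address but a different SNI: NOT the same peer, so no reuse.
    let b = ReuseKey { sni: "example.com".into(), ..a.clone() };
    assert_ne!(a, b);
    // An identical key is the same peer, so its connections are reusable.
    assert_eq!(a, a.clone());
}
```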
## Disable pooling
To disable connection pooling and reuse for a certain `Peer`, just set the `idle_timeout` to 0 seconds for all requests using that `Peer`.
## Failure
A connection is considered not reusable if errors happen during the request.

docs/user_guide/prom.md
@@ -0,0 +1,22 @@
# Prometheus
Pingora has a built-in prometheus HTTP metric server for scraping.
```rust
...
let mut prometheus_service_http = Service::prometheus_http_service();
prometheus_service_http.add_tcp("0.0.0.0:1234");
my_server.add_service(prometheus_service_http);
my_server.run_forever();
```
The simplest way to use it is to have [static metrics](https://docs.rs/prometheus/latest/prometheus/#static-metrics).
```rust
static MY_COUNTER: Lazy<IntGauge> = Lazy::new(|| {
register_int_gauge!("my_counter", "my counter").unwrap()
});
```
This static metric will automatically appear in the Prometheus metric endpoint.

@@ -0,0 +1,27 @@
# Starting and stopping Pingora server
A pingora server is a regular unprivileged multithreaded process.
## Start
By default, the server will run in the foreground.
A Pingora server by default takes the following command-line arguments:
| Argument | Effect | default|
| ------------- |-------------| ----|
| -d, --daemon | Daemonize the server | false |
| -t, --test | Test the server conf and then exit (WIP) | false |
| -c, --conf | The path to the configuration file | empty string |
| -u, --upgrade | This server should gracefully upgrade a running server | false |
## Stop
A Pingora server will listen to the following signals.
### SIGINT: fast shutdown
Upon receiving SIGINT (ctrl + c), the server will exit immediately with no delay. All unfinished requests will be interrupted. This behavior is usually less preferred because it could break requests.
### SIGTERM: graceful shutdown
Upon receiving SIGTERM, the server will notify all its services to shutdown, wait for some preconfigured time and then exit. This behavior gives requests a grace period to finish.
### SIGQUIT: graceful upgrade
Similar to SIGTERM, but the server will also transfer all its listening sockets to a new Pingora server so that there is no downtime during the upgrade. See the [graceful upgrade](graceful.md) section for more details.

@@ -0,0 +1,14 @@
# Systemd integration
A Pingora server doesn't depend on systemd but it can easily be made into a systemd service.
```ini
[Service]
Type=forking
PIDFile=/run/pingora.pid
ExecStart=/bin/pingora -d -c /etc/pingora.conf
ExecReload=kill -QUIT $MAINPID
ExecReload=/bin/pingora -u -d -c /etc/pingora.conf
```
The example systemd setup integrates Pingora's graceful upgrade into systemd. To upgrade the pingora service, simply install a new version of the binary and then call `systemctl reload pingora.service`.

@@ -0,0 +1,36 @@
[package]
name = "pingora-boringssl"
version = "0.1.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
repository = "https://github.com/cloudflare/pingora"
categories = ["asynchronous", "network-programming"]
keywords = ["async", "tls", "ssl", "pingora"]
description = """
BoringSSL async APIs for Pingora.
"""
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "pingora_boringssl"
path = "src/lib.rs"
[dependencies]
boring = { version = "4.5", features = ["pq-experimental"] }
boring-sys = "4.5"
futures-util = { version = "0.3", default-features = false }
tokio = { workspace = true, features = ["io-util", "net", "macros", "rt-multi-thread"] }
libc = "0.2.70"
foreign-types-shared = { version = "0.3" }
[dev-dependencies]
tokio-test = "0.4"
tokio = { workspace = true, features = ["full"] }
[features]
default = []
pq_use_second_keyshare = []
# waiting for boring-rs release
read_uninit = []

pingora-boringssl/LICENSE
@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@@ -0,0 +1,305 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! This file reimplements tokio-boring with the [overhauled](https://github.com/sfackler/tokio-openssl/commit/56f6618ab619f3e431fa8feec2d20913bf1473aa)
//! tokio-openssl interface, while the tokio APIs from the official [boring] crate have not yet caught up to it.
use boring::error::ErrorStack;
use boring::ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef, SslStream as SslStreamCore};
use futures_util::future;
use std::fmt;
use std::io::{self, Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
struct StreamWrapper<S> {
stream: S,
context: usize,
}
impl<S> fmt::Debug for StreamWrapper<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.stream, fmt)
}
}
impl<S> StreamWrapper<S> {
/// # Safety
///
/// Must be called with `context` set to a valid pointer to a live `Context` object, and the
/// wrapper must be pinned in memory.
unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
debug_assert_ne!(self.context, 0);
let stream = Pin::new_unchecked(&mut self.stream);
let context = &mut *(self.context as *mut _);
(stream, context)
}
}
impl<S> Read for StreamWrapper<S>
where
S: AsyncRead,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let (stream, cx) = unsafe { self.parts() };
let mut buf = ReadBuf::new(buf);
match stream.poll_read(cx, &mut buf)? {
Poll::Ready(()) => Ok(buf.filled().len()),
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}
impl<S> Write for StreamWrapper<S>
where
S: AsyncWrite,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let (stream, cx) = unsafe { self.parts() };
match stream.poll_write(cx, buf) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
fn flush(&mut self) -> io::Result<()> {
let (stream, cx) = unsafe { self.parts() };
match stream.poll_flush(cx) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}
fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(e) => match e.code() {
ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
_ => Poll::Ready(Err(e)),
},
}
}
/// An asynchronous version of [`boring::ssl::SslStream`].
#[derive(Debug)]
pub struct SslStream<S>(SslStreamCore<StreamWrapper<S>>);
impl<S: AsyncRead + AsyncWrite> SslStream<S> {
/// Like [`SslStream::new`](ssl::SslStream::new).
pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
SslStreamCore::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
}
/// Like [`SslStream::connect`](ssl::SslStream::connect).
pub fn poll_connect(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), ssl::Error>> {
self.with_context(cx, |s| cvt_ossl(s.connect()))
}
/// A convenience method wrapping [`poll_connect`](Self::poll_connect).
pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
}
/// Like [`SslStream::accept`](ssl::SslStream::accept).
pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
self.with_context(cx, |s| cvt_ossl(s.accept()))
}
/// A convenience method wrapping [`poll_accept`](Self::poll_accept).
pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
}
/// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
pub fn poll_do_handshake(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), ssl::Error>> {
self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
}
/// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
}
// TODO: early data
}
impl<S> SslStream<S> {
/// Returns a shared reference to the `Ssl` object associated with this stream.
pub fn ssl(&self) -> &SslRef {
self.0.ssl()
}
/// Returns a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.0.get_ref().stream
}
/// Returns a mutable reference to the underlying stream.
pub fn get_mut(&mut self) -> &mut S {
&mut self.0.get_mut().stream
}
/// Returns a pinned mutable reference to the underlying stream.
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
}
fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
where
F: FnOnce(&mut SslStreamCore<StreamWrapper<S>>) -> R,
{
let this = unsafe { self.get_unchecked_mut() };
this.0.get_mut().context = ctx as *mut _ as usize;
let r = f(&mut this.0);
this.0.get_mut().context = 0;
r
}
}
#[cfg(feature = "read_uninit")]
impl<S> AsyncRead for SslStream<S>
where
S: AsyncRead + AsyncWrite,
{
fn poll_read(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.with_context(ctx, |s| {
// SAFETY: read_uninit does not de-initialize the buffer.
match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
Poll::Ready(nread) => {
unsafe {
buf.assume_init(nread);
}
buf.advance(nread);
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
})
}
}
#[cfg(not(feature = "read_uninit"))]
impl<S> AsyncRead for SslStream<S>
where
S: AsyncRead + AsyncWrite,
{
fn poll_read(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.with_context(ctx, |s| {
// This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though
// OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now.
let slice = unsafe {
let buf = buf.unfilled_mut();
std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast::<u8>(), buf.len())
};
match cvt(s.read(slice))? {
Poll::Ready(nread) => {
unsafe {
buf.assume_init(nread);
}
buf.advance(nread);
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
})
}
}
impl<S> AsyncWrite for SslStream<S>
where
S: AsyncRead + AsyncWrite,
{
fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
self.with_context(ctx, |s| cvt(s.write(buf)))
}
fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
self.with_context(ctx, |s| cvt(s.flush()))
}
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
match self.as_mut().with_context(ctx, |s| s.shutdown()) {
Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
return Poll::Pending;
}
Err(e) => {
return Poll::Ready(Err(e
.into_io_error()
.unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
}
}
self.get_pin_mut().poll_shutdown(ctx)
}
}
#[tokio::test]
async fn test_google() {
use boring::ssl;
use std::net::ToSocketAddrs;
use std::pin::Pin;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
let addr = "8.8.8.8:443".to_socket_addrs().unwrap().next().unwrap();
let stream = TcpStream::connect(&addr).await.unwrap();
let ssl_context = ssl::SslContext::builder(ssl::SslMethod::tls())
.unwrap()
.build();
let ssl = ssl::Ssl::new(&ssl_context).unwrap();
let mut stream = crate::tokio_ssl::SslStream::new(ssl, stream).unwrap();
Pin::new(&mut stream).connect().await.unwrap();
stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await.unwrap();
let mut buf = vec![];
stream.read_to_end(&mut buf).await.unwrap();
let response = String::from_utf8_lossy(&buf);
let response = response.trim_end();
// any response code is fine
assert!(response.starts_with("HTTP/1.0 "));
assert!(response.ends_with("</html>") || response.ends_with("</HTML>"));
}
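The core trick in the file above is bridging tokio's poll-based I/O to the blocking-style `Read`/`Write` traits that the SSL stream expects: `Poll::Pending` is surfaced as `WouldBlock` on the way into BoringSSL, and `cvt` maps `WouldBlock` back to `Poll::Pending` on the way out. A minimal std-only sketch of that convention (names here are illustrative, not part of the crate):

```rust
use std::io;
use std::task::Poll;

// Sync side: a reader that is not ready signals WouldBlock,
// mirroring what StreamWrapper's Read impl does on Poll::Pending.
fn pending_read(_buf: &mut [u8]) -> io::Result<usize> {
    Err(io::Error::from(io::ErrorKind::WouldBlock))
}

// Async side: translate WouldBlock back into Poll::Pending, as cvt() does.
fn to_poll<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Ok(v) => Poll::Ready(Ok(v)),
        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        Err(e) => Poll::Ready(Err(e)),
    }
}

fn main() {
    let mut buf = [0u8; 4];
    // A not-ready read round-trips to Poll::Pending.
    assert!(to_poll(pending_read(&mut buf)).is_pending());
    // A completed read round-trips to Poll::Ready.
    assert!(matches!(to_poll::<usize>(Ok(3)), Poll::Ready(Ok(3))));
}
```

This is why `StreamWrapper::parts` must recover the stashed `Context`: the synchronous `Read`/`Write` calls made from inside BoringSSL have no other way to reach the task's waker.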


@@ -0,0 +1,192 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Extended functionalities that are not yet exposed via the [`boring`] APIs
use boring::error::ErrorStack;
use boring::pkey::{HasPrivate, PKeyRef};
use boring::ssl::{Ssl, SslAcceptor, SslRef};
use boring::x509::store::X509StoreRef;
use boring::x509::verify::X509VerifyParamRef;
use boring::x509::X509Ref;
use foreign_types_shared::ForeignTypeRef;
use libc::*;
use std::ffi::CString;
fn cvt(r: c_int) -> Result<c_int, ErrorStack> {
if r != 1 {
Err(ErrorStack::get())
} else {
Ok(r)
}
}
/// Add name as an additional reference identifier that can match the peer's certificate
///
/// See [X509_VERIFY_PARAM_set1_host](https://www.openssl.org/docs/man3.1/man3/X509_VERIFY_PARAM_set1_host.html).
pub fn add_host(verify_param: &mut X509VerifyParamRef, host: &str) -> Result<(), ErrorStack> {
if host.is_empty() {
return Ok(());
}
unsafe {
cvt(boring_sys::X509_VERIFY_PARAM_add1_host(
verify_param.as_ptr(),
host.as_ptr() as *const _,
host.len(),
))
.map(|_| ())
}
}
/// Set the verify cert store of `ssl`
///
/// See [SSL_set1_verify_cert_store](https://www.openssl.org/docs/man1.1.1/man3/SSL_set1_verify_cert_store.html).
pub fn ssl_set_verify_cert_store(
ssl: &mut SslRef,
cert_store: &X509StoreRef,
) -> Result<(), ErrorStack> {
unsafe {
cvt(boring_sys::SSL_set1_verify_cert_store(
ssl.as_ptr(),
cert_store.as_ptr(),
))?;
}
Ok(())
}
/// Load the certificate into `ssl`
///
/// See [SSL_use_certificate](https://www.openssl.org/docs/man1.1.1/man3/SSL_use_certificate.html).
pub fn ssl_use_certificate(ssl: &mut SslRef, cert: &X509Ref) -> Result<(), ErrorStack> {
unsafe {
cvt(boring_sys::SSL_use_certificate(ssl.as_ptr(), cert.as_ptr()))?;
}
Ok(())
}
/// Load the private key into `ssl`
///
/// See [SSL_use_PrivateKey](https://www.openssl.org/docs/man1.1.1/man3/SSL_use_PrivateKey.html).
pub fn ssl_use_private_key<T>(ssl: &mut SslRef, key: &PKeyRef<T>) -> Result<(), ErrorStack>
where
T: HasPrivate,
{
unsafe {
cvt(boring_sys::SSL_use_PrivateKey(ssl.as_ptr(), key.as_ptr()))?;
}
Ok(())
}
/// Add the certificate into the cert chain of `ssl`
///
/// See [SSL_add1_chain_cert](https://www.openssl.org/docs/man1.1.1/man3/SSL_add1_chain_cert.html)
pub fn ssl_add_chain_cert(ssl: &mut SslRef, cert: &X509Ref) -> Result<(), ErrorStack> {
unsafe {
cvt(boring_sys::SSL_add1_chain_cert(ssl.as_ptr(), cert.as_ptr()))?;
}
Ok(())
}
/// Set renegotiation
///
/// This function is specific to BoringSSL
/// See <https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_set_renegotiate_mode>
pub fn ssl_set_renegotiate_mode_freely(ssl: &mut SslRef) {
unsafe {
boring_sys::SSL_set_renegotiate_mode(
ssl.as_ptr(),
boring_sys::ssl_renegotiate_mode_t::ssl_renegotiate_freely,
);
}
}
/// Set the curves/groups of `ssl`
///
/// See [set_groups_list](https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set1_curves.html).
pub fn ssl_set_groups_list(ssl: &mut SslRef, groups: &str) -> Result<(), ErrorStack> {
let groups = CString::new(groups).unwrap();
unsafe {
// SSL_set1_groups_list doesn't exist in boring-sys, but SSL_set1_curves_list is equivalent
cvt(boring_sys::SSL_set1_curves_list(
ssl.as_ptr(),
groups.as_ptr(),
))?;
}
Ok(())
}
/// Sets whether a second keyshare is sent in the client hello when PQ is used.
///
/// Default is true. When `true`, the first PQ keyshare (if any) and the non-PQ keyshare are sent.
/// When `false`, only the first configured keyshare is sent.
#[cfg(feature = "pq_use_second_keyshare")]
pub fn ssl_use_second_key_share(ssl: &mut SslRef, enabled: bool) {
unsafe { boring_sys::SSL_use_second_keyshare(ssl.as_ptr(), enabled as _) }
}
#[cfg(not(feature = "pq_use_second_keyshare"))]
pub fn ssl_use_second_key_share(_ssl: &mut SslRef, _enabled: bool) {}
/// Clear the error stack
///
/// SSL calls should check and clear the BoringSSL error stack, but some calls fail to do so.
/// This causes the next unrelated SSL call to fail due to the leftover errors. This function allows
/// the caller to clear the error stack before performing SSL calls to avoid this issue.
pub fn clear_error_stack() {
let _ = ErrorStack::get();
}
/// Create a new [Ssl] from &[SslAcceptor]
///
/// This function is needed because [Ssl::new()] doesn't take `&SslContextRef` like openssl-rs
pub fn ssl_from_acceptor(acceptor: &SslAcceptor) -> Result<Ssl, ErrorStack> {
Ssl::new_from_ref(acceptor.context())
}
/// Suspend the TLS handshake when a certificate is needed.
///
/// This function causes the TLS handshake to pause and return the error SSL_ERROR_WANT_X509_LOOKUP.
/// The caller should set the certificate and then call [unblock_ssl_cert()] before continuing the
/// handshake on the TLS connection.
pub fn suspend_when_need_ssl_cert(ssl: &mut SslRef) {
unsafe {
boring_sys::SSL_set_cert_cb(ssl.as_ptr(), Some(raw_cert_block), std::ptr::null_mut());
}
}
/// Unblock a TLS handshake after the certificate is set.
///
/// The user should resume the TLS handshake after this function is called.
pub fn unblock_ssl_cert(ssl: &mut SslRef) {
unsafe {
boring_sys::SSL_set_cert_cb(ssl.as_ptr(), None, std::ptr::null_mut());
}
}
// Just block the handshake
extern "C" fn raw_cert_block(_ssl: *mut boring_sys::SSL, _arg: *mut c_void) -> c_int {
-1
}
/// Whether the TLS error is SSL_ERROR_WANT_X509_LOOKUP
pub fn is_suspended_for_cert(error: &boring::ssl::Error) -> bool {
error.code().as_raw() == boring_sys::SSL_ERROR_WANT_X509_LOOKUP
}
#[allow(clippy::mut_from_ref)]
/// Get a mutable `SslRef` out of a shared `SslRef`, which is missing functionality for certain `SslStream` uses
/// # Safety
/// The caller needs to make sure that they hold a `&mut SslRef`
pub unsafe fn ssl_mut(ssl: &SslRef) -> &mut SslRef {
unsafe { SslRef::from_ptr_mut(ssl.as_ptr()) }
}


@@ -0,0 +1,34 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! The BoringSSL API compatibility layer.
//!
//! This crate aims to make [boring] APIs exchangeable with [openssl-rs](https://docs.rs/openssl/latest/openssl/).
//! In other words, this crate and `pingora-openssl` expose identical Rust APIs.
#![warn(clippy::all)]
use boring as ssl_lib;
pub use boring_sys as ssl_sys;
pub mod boring_tokio;
pub use boring_tokio as tokio_ssl;
pub mod ext;
// export commonly used libs
pub use ssl_lib::error;
pub use ssl_lib::hash;
pub use ssl_lib::nid;
pub use ssl_lib::pkey;
pub use ssl_lib::ssl;
pub use ssl_lib::x509;

pingora-cache/Cargo.toml Normal file

@@ -0,0 +1,64 @@
[package]
name = "pingora-cache"
version = "0.1.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
repository = "https://github.com/cloudflare/pingora"
categories = ["asynchronous", "network-programming"]
keywords = ["async", "http", "cache"]
description = """
HTTP caching APIs for Pingora proxy.
"""
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "pingora_cache"
path = "src/lib.rs"
[dependencies]
pingora-core = { version = "0.1.0", path = "../pingora-core" }
pingora-error = { version = "0.1.0", path = "../pingora-error" }
pingora-header-serde = { version = "0.1.0", path = "../pingora-header-serde" }
pingora-http = { version = "0.1.0", path = "../pingora-http" }
pingora-lru = { version = "0.1.0", path = "../pingora-lru" }
pingora-timeout = { version = "0.1.0", path = "../pingora-timeout" }
http = { workspace = true }
indexmap = "1"
once_cell = { workspace = true }
regex = "1"
blake2 = "0.10"
serde = { version = "1.0", features = ["derive"] }
rmp-serde = "1"
bytes = { workspace = true }
httpdate = "1.0.2"
log = { workspace = true }
async-trait = { workspace = true }
parking_lot = "0.12"
rustracing = "0.5.1"
rustracing_jaeger = "0.7"
rmp = "0.8"
tokio = { workspace = true }
lru = { workspace = true }
ahash = { workspace = true }
hex = "0.4"
httparse = { workspace = true }
[dev-dependencies]
tokio-test = "0.4"
tokio = { workspace = true, features = ["fs"] }
env_logger = "0.9"
dhat = "0"
futures = "0.3"
[[bench]]
name = "simple_lru_memory"
harness = false
[[bench]]
name = "lru_memory"
harness = false
[[bench]]
name = "lru_serde"
harness = false

pingora-cache/LICENSE Normal file

@@ -0,0 +1,202 @@
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

@@ -0,0 +1,96 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[global_allocator]
static ALLOC: dhat::Alloc = dhat::Alloc;
use pingora_cache::{
eviction::{lru::Manager, EvictionManager},
CacheKey,
};
const ITEMS: usize = 5 * usize::pow(2, 20);
/*
Total: 681,836,456 bytes (100%, 28,192,797.16/s) in 10,485,845 blocks (100%, 433,572.15/s), avg size 65.02 bytes, avg lifetime 5,935,075.17 µs (24.54% of program duration)
At t-gmax: 569,114,536 bytes (100%) in 5,242,947 blocks (100%), avg size 108.55 bytes
At t-end: 88 bytes (100%) in 3 blocks (100%), avg size 29.33 bytes
Allocated at {
#0: [root]
}
PP 1.1/5 {
Total: 293,601,280 bytes (43.06%, 12,139,921.91/s) in 5,242,880 blocks (50%, 216,784.32/s), avg size 56 bytes, avg lifetime 11,870,032.65 µs (49.08% of program duration)
Max: 293,601,280 bytes in 5,242,880 blocks, avg size 56 bytes
At t-gmax: 293,601,280 bytes (51.59%) in 5,242,880 blocks (100%), avg size 56 bytes
At t-end: 0 bytes (0%) in 0 blocks (0%), avg size 0 bytes
Allocated at {
#1: 0x5555703cf69c: alloc::alloc::exchange_malloc (alloc/src/alloc.rs:326:11)
#2: 0x5555703cf69c: alloc::boxed::Box<T>::new (alloc/src/boxed.rs:215:9)
#3: 0x5555703cf69c: pingora_lru::LruUnit<T>::admit (pingora-lru/src/lib.rs:201:20)
#4: 0x5555703cf69c: pingora_lru::Lru<T,_>::admit (pingora-lru/src/lib.rs:48:26)
#5: 0x5555703cf69c: <pingora_cache::eviction::lru::Manager<_> as pingora_cache::eviction::EvictionManager>::admit (src/eviction/lru.rs:114:9)
#6: 0x5555703cf69c: lru_memory::main (pingora-cache/benches/lru_memory.rs:78:9)
}
}
PP 1.2/5 {
Total: 203,685,456 bytes (29.87%, 8,422,052.97/s) in 50 blocks (0%, 2.07/s), avg size 4,073,709.12 bytes, avg lifetime 6,842,528.74 µs (28.29% of program duration)
Max: 132,906,576 bytes in 32 blocks, avg size 4,153,330.5 bytes
At t-gmax: 132,906,576 bytes (23.35%) in 32 blocks (0%), avg size 4,153,330.5 bytes
At t-end: 0 bytes (0%) in 0 blocks (0%), avg size 0 bytes
Allocated at {
#1: 0x5555703cec54: <alloc::alloc::Global as core::alloc::Allocator>::allocate (alloc/src/alloc.rs:237:9)
#2: 0x5555703cec54: alloc::raw_vec::RawVec<T,A>::allocate_in (alloc/src/raw_vec.rs:185:45)
#3: 0x5555703cec54: alloc::raw_vec::RawVec<T,A>::with_capacity_in (alloc/src/raw_vec.rs:131:9)
#4: 0x5555703cec54: alloc::vec::Vec<T,A>::with_capacity_in (src/vec/mod.rs:641:20)
#5: 0x5555703cec54: alloc::vec::Vec<T>::with_capacity (src/vec/mod.rs:483:9)
#6: 0x5555703cec54: pingora_lru::linked_list::Nodes::with_capacity (pingora-lru/src/linked_list.rs:50:25)
#7: 0x5555703cec54: pingora_lru::linked_list::LinkedList::with_capacity (pingora-lru/src/linked_list.rs:121:20)
#8: 0x5555703cec54: pingora_lru::LruUnit<T>::with_capacity (pingora-lru/src/lib.rs:176:20)
#9: 0x5555703cec54: pingora_lru::Lru<T,_>::with_capacity (pingora-lru/src/lib.rs:28:36)
#10: 0x5555703cec54: pingora_cache::eviction::lru::Manager<_>::with_capacity (src/eviction/lru.rs:22:17)
#11: 0x5555703cec54: lru_memory::main (pingora-cache/benches/lru_memory.rs:74:19)
}
}
PP 1.3/5 {
Total: 142,606,592 bytes (20.92%, 5,896,544.09/s) in 32 blocks (0%, 1.32/s), avg size 4,456,456 bytes, avg lifetime 22,056,252.88 µs (91.2% of program duration)
Max: 142,606,592 bytes in 32 blocks, avg size 4,456,456 bytes
At t-gmax: 142,606,592 bytes (25.06%) in 32 blocks (0%), avg size 4,456,456 bytes
At t-end: 0 bytes (0%) in 0 blocks (0%), avg size 0 bytes
Allocated at {
#1: 0x5555703ceb64: alloc::alloc::alloc (alloc/src/alloc.rs:95:14)
#2: 0x5555703ceb64: <hashbrown::raw::alloc::inner::Global as hashbrown::raw::alloc::inner::Allocator>::allocate (src/raw/alloc.rs:47:35)
#3: 0x5555703ceb64: hashbrown::raw::alloc::inner::do_alloc (src/raw/alloc.rs:62:9)
#4: 0x5555703ceb64: hashbrown::raw::RawTableInner<A>::new_uninitialized (src/raw/mod.rs:1080:38)
#5: 0x5555703ceb64: hashbrown::raw::RawTableInner<A>::fallible_with_capacity (src/raw/mod.rs:1109:30)
#6: 0x5555703ceb64: hashbrown::raw::RawTable<T,A>::fallible_with_capacity (src/raw/mod.rs:460:20)
#7: 0x5555703ceb64: hashbrown::raw::RawTable<T,A>::with_capacity_in (src/raw/mod.rs:481:15)
#8: 0x5555703ceb64: hashbrown::raw::RawTable<T>::with_capacity (src/raw/mod.rs:411:9)
#9: 0x5555703ceb64: hashbrown::map::HashMap<K,V,S>::with_capacity_and_hasher (hashbrown-0.12.3/src/map.rs:422:20)
#10: 0x5555703ceb64: hashbrown::map::HashMap<K,V>::with_capacity (hashbrown-0.12.3/src/map.rs:326:9)
#11: 0x5555703ceb64: pingora_lru::LruUnit<T>::with_capacity (pingora-lru/src/lib.rs:175:27)
#12: 0x5555703ceb64: pingora_lru::Lru<T,_>::with_capacity (pingora-lru/src/lib.rs:28:36)
#13: 0x5555703ceb64: pingora_cache::eviction::lru::Manager<_>::with_capacity (src/eviction/lru.rs:22:17)
#14: 0x5555703ceb64: lru_memory::main (pingora-cache/benches/lru_memory.rs:74:19)
}
}
*/
fn main() {
let _profiler = dhat::Profiler::new_heap();
let manager = Manager::<32>::with_capacity(ITEMS, ITEMS / 32);
let unused_ttl = std::time::SystemTime::now();
for i in 0..ITEMS {
let item = CacheKey::new("", i.to_string(), "").to_compact();
manager.admit(item, 1, unused_ttl);
}
}

@@ -0,0 +1,46 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Instant;
use pingora_cache::{
eviction::{lru::Manager, EvictionManager},
CacheKey,
};
const ITEMS: usize = 5 * usize::pow(2, 20);
fn main() {
let manager = Manager::<32>::with_capacity(ITEMS, ITEMS / 32);
let manager2 = Manager::<32>::with_capacity(ITEMS, ITEMS / 32);
let unused_ttl = std::time::SystemTime::now();
for i in 0..ITEMS {
let item = CacheKey::new("", i.to_string(), "").to_compact();
manager.admit(item, 1, unused_ttl);
}
/* lru serialize shard 19 22.573338ms, 5241623 bytes
* lru deserialize shard 19 39.260669ms, 5241623 bytes */
for i in 0..32 {
let before = Instant::now();
let ser = manager.serialize_shard(i).unwrap();
let elapsed = before.elapsed();
println!("lru serialize shard {i} {elapsed:?}, {} bytes", ser.len());
let before = Instant::now();
manager2.deserialize_shard(&ser).unwrap();
let elapsed = before.elapsed();
println!("lru deserialize shard {i} {elapsed:?}, {} bytes", ser.len());
}
}

@@ -0,0 +1,78 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[global_allocator]
static ALLOC: dhat::Alloc = dhat::Alloc;
use pingora_cache::{
eviction::{simple_lru::Manager, EvictionManager},
CacheKey,
};
const ITEMS: usize = 5 * usize::pow(2, 20);
/*
Total: 704,643,412 bytes (100%, 29,014,058.85/s) in 10,485,787 blocks (100%, 431,757.73/s), avg size 67.2 bytes, avg lifetime 6,163,799.09 µs (25.38% of program duration)
At t-gmax: 520,093,936 bytes (100%) in 5,242,886 blocks (100%), avg size 99.2 bytes
PP 1.1/4 {
Total: 377,487,360 bytes (53.57%, 15,543,238.31/s) in 5,242,880 blocks (50%, 215,878.31/s), avg size 72 bytes, avg lifetime 12,327,602.83 µs (50.76% of program duration)
Max: 377,487,360 bytes in 5,242,880 blocks, avg size 72 bytes
At t-gmax: 377,487,360 bytes (72.58%) in 5,242,880 blocks (100%), avg size 72 bytes
At t-end: 0 bytes (0%) in 0 blocks (0%), avg size 0 bytes
Allocated at {
#1: 0x5555791dd7e0: alloc::alloc::exchange_malloc (alloc/src/alloc.rs:326:11)
#2: 0x5555791dd7e0: alloc::boxed::Box<T>::new (alloc/src/boxed.rs:215:9)
#3: 0x5555791dd7e0: lru::LruCache<K,V,S>::replace_or_create_node (lru-0.8.1/src/lib.rs:391:20)
#4: 0x5555791dd7e0: lru::LruCache<K,V,S>::capturing_put (lru-0.8.1/src/lib.rs:355:44)
#5: 0x5555791dd7e0: lru::LruCache<K,V,S>::push (lru-0.8.1/src/lib.rs:334:9)
#6: 0x5555791dd7e0: pingora_cache::eviction::simple_lru::Manager::insert (src/eviction/simple_lru.rs:49:23)
#7: 0x5555791dd7e0: <pingora_cache::eviction::simple_lru::Manager as pingora_cache::eviction::EvictionManager>::admit (src/eviction/simple_lru.rs:166:9)
#8: 0x5555791dd7e0: simple_lru_memory::main (pingora-cache/benches/simple_lru_memory.rs:21:9)
}
}
PP 1.2/4 {
Total: 285,212,780 bytes (40.48%, 11,743,784.5/s) in 22 blocks (0%, 0.91/s), avg size 12,964,217.27 bytes, avg lifetime 1,116,774.23 µs (4.6% of program duration)
Max: 213,909,520 bytes in 2 blocks, avg size 106,954,760 bytes
At t-gmax: 142,606,344 bytes (27.42%) in 1 blocks (0%), avg size 142,606,344 bytes
At t-end: 0 bytes (0%) in 0 blocks (0%), avg size 0 bytes
Allocated at {
#1: 0x5555791dae20: alloc::alloc::alloc (alloc/src/alloc.rs:95:14)
#2: 0x5555791dae20: <hashbrown::raw::alloc::inner::Global as hashbrown::raw::alloc::inner::Allocator>::allocate (src/raw/alloc.rs:47:35)
#3: 0x5555791dae20: hashbrown::raw::alloc::inner::do_alloc (src/raw/alloc.rs:62:9)
#4: 0x5555791dae20: hashbrown::raw::RawTableInner<A>::new_uninitialized (src/raw/mod.rs:1080:38)
#5: 0x5555791dae20: hashbrown::raw::RawTableInner<A>::fallible_with_capacity (src/raw/mod.rs:1109:30)
#6: 0x5555791dae20: hashbrown::raw::RawTableInner<A>::prepare_resize (src/raw/mod.rs:1353:29)
#7: 0x5555791dae20: hashbrown::raw::RawTableInner<A>::resize_inner (src/raw/mod.rs:1426:29)
#8: 0x5555791dae20: hashbrown::raw::RawTableInner<A>::reserve_rehash_inner (src/raw/mod.rs:1403:13)
#9: 0x5555791dae20: hashbrown::raw::RawTable<T,A>::reserve_rehash (src/raw/mod.rs:680:13)
#10: 0x5555791dde50: hashbrown::raw::RawTable<T,A>::reserve (src/raw/mod.rs:646:16)
#11: 0x5555791dde50: hashbrown::raw::RawTable<T,A>::insert (src/raw/mod.rs:725:17)
#12: 0x5555791dde50: hashbrown::map::HashMap<K,V,S,A>::insert (hashbrown-0.12.3/src/map.rs:1679:13)
#13: 0x5555791dde50: lru::LruCache<K,V,S>::capturing_put (lru-0.8.1/src/lib.rs:361:17)
#14: 0x5555791dde50: lru::LruCache<K,V,S>::push (lru-0.8.1/src/lib.rs:334:9)
#15: 0x5555791dde50: pingora_cache::eviction::simple_lru::Manager::insert (src/eviction/simple_lru.rs:49:23)
#16: 0x5555791dde50: <pingora_cache::eviction::simple_lru::Manager as pingora_cache::eviction::EvictionManager>::admit (src/eviction/simple_lru.rs:166:9)
#17: 0x5555791dde50: simple_lru_memory::main (pingora-cache/benches/simple_lru_memory.rs:21:9)
}
}
*/
fn main() {
let _profiler = dhat::Profiler::new_heap();
let manager = Manager::new(ITEMS);
let unused_ttl = std::time::SystemTime::now();
for i in 0..ITEMS {
let item = CacheKey::new("", i.to_string(), "").to_compact();
manager.admit(item, 1, unused_ttl);
}
}

@@ -0,0 +1,839 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Functions and utilities to help parse Cache-Control headers
use super::*;
use http::header::HeaderName;
use http::HeaderValue;
use indexmap::IndexMap;
use once_cell::sync::Lazy;
use pingora_error::{Error, ErrorType, Result};
use pingora_http::ResponseHeader;
use regex::bytes::Regex;
use std::num::IntErrorKind;
use std::slice;
use std::str;
/// The max delta-seconds per [RFC 7234](https://datatracker.ietf.org/doc/html/rfc7234#section-1.2.1)
// "If a cache receives a delta-seconds
// value greater than the greatest integer it can represent, or if any
// of its subsequent calculations overflows, the cache MUST consider the
// value to be either 2147483648 (2^31) or the greatest positive integer
// it can conveniently represent."
pub const DELTA_SECONDS_OVERFLOW_VALUE: u32 = 2147483648;
/// Cache control directive key type
pub type DirectiveKey = String;
/// Cache control directive value type
#[derive(Debug)]
pub struct DirectiveValue(pub Vec<u8>);
impl AsRef<[u8]> for DirectiveValue {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl DirectiveValue {
/// A [DirectiveValue] without quotes (`"`).
pub fn parse_as_bytes(&self) -> &[u8] {
self.0
.strip_prefix(&[b'"'])
.and_then(|bytes| bytes.strip_suffix(&[b'"']))
.unwrap_or(&self.0[..])
}
/// A [DirectiveValue] without quotes (`"`) as `str`.
pub fn parse_as_str(&self) -> Result<&str> {
str::from_utf8(self.parse_as_bytes()).or_else(|e| {
Error::e_because(ErrorType::InternalError, "could not parse value as utf8", e)
})
}
/// Parse the [DirectiveValue] as delta seconds
///
/// `"`s are ignored. The value is capped to [DELTA_SECONDS_OVERFLOW_VALUE].
pub fn parse_as_delta_seconds(&self) -> Result<u32> {
match self.parse_as_str()?.parse::<u32>() {
Ok(value) => Ok(value),
Err(e) => {
// delta-seconds expect to handle positive overflow gracefully
if e.kind() == &IntErrorKind::PosOverflow {
Ok(DELTA_SECONDS_OVERFLOW_VALUE)
} else {
Error::e_because(ErrorType::InternalError, "could not parse value as u32", e)
}
}
}
}
}
/// An ordered map to store cache control key value pairs.
pub type DirectiveMap = IndexMap<DirectiveKey, Option<DirectiveValue>>;
/// Parsed Cache-Control directives
#[derive(Debug)]
pub struct CacheControl {
/// The parsed directives
pub directives: DirectiveMap,
}
/// Cacheability calculated from cache control.
#[derive(Debug, PartialEq, Eq)]
pub enum Cacheable {
/// Cacheable
Yes,
/// Not cacheable
No,
/// No directive found for explicit cacheability
Default,
}
/// An iter over all the cache control directives
pub struct ListValueIter<'a>(slice::Split<'a, u8, fn(&u8) -> bool>);
impl<'a> ListValueIter<'a> {
pub fn from(value: &'a DirectiveValue) -> Self {
ListValueIter(value.parse_as_bytes().split(|byte| byte == &b','))
}
}
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.3
// optional whitespace OWS = *(SP / HTAB); SP = 0x20, HTAB = 0x09
fn trim_ows(bytes: &[u8]) -> &[u8] {
fn not_ows(b: &u8) -> bool {
b != &b'\x20' && b != &b'\x09'
}
// find first non-OWS char from front (head) and from end (tail)
let head = bytes.iter().position(not_ows).unwrap_or(0);
let tail = bytes
.iter()
.rposition(not_ows)
.map(|rpos| rpos + 1)
.unwrap_or(head);
&bytes[head..tail]
}
impl<'a> Iterator for ListValueIter<'a> {
type Item = &'a [u8];
fn next(&mut self) -> Option<Self::Item> {
Some(trim_ows(self.0.next()?))
}
}
/*
Originally from https://github.com/hapijs/wreck:
Cache-Control = 1#cache-directive
cache-directive = token [ "=" ( token / quoted-string ) ]
token = [^\x00-\x20\(\)<>@\,;\:\\"\/\[\]\?\=\{\}\x7F]+
quoted-string = "(?:[^"\\]|\\.)*"
*/
static RE_CACHE_DIRECTIVE: Lazy<Regex> =
// unicode support disabled; allows ';' or ',' as delimiters | capture groups: 1 = directive, 2 = token or quoted-string
Lazy::new(|| {
Regex::new(r#"(?-u)(?:^|(?:\s*[,;]\s*))([^\x00-\x20\(\)<>@,;:\\"/\[\]\?=\{\}\x7F]+)(?:=((?:[^\x00-\x20\(\)<>@,;:\\"/\[\]\?=\{\}\x7F]+|(?:"(?:[^"\\]|\\.)*"))))?"#).unwrap()
});
impl CacheControl {
// Our parsing strategy is more permissive than the RFC in a few ways:
// - Allows semicolons as delimiters (in addition to commas).
// - Allows octets outside of visible ASCII in tokens.
// - Doesn't require no-value for "boolean directives," such as must-revalidate
// - Allows quoted-string format for numeric values.
fn from_headers(headers: http::header::GetAll<HeaderValue>) -> Option<Self> {
let mut directives = IndexMap::new();
// should iterate in header line insertion order
for line in headers {
for captures in RE_CACHE_DIRECTIVE.captures_iter(line.as_bytes()) {
// directive key
// header values don't have to be utf-8, but we store keys as strings for case-insensitive hashing
let key = captures.get(1).and_then(|cap| {
str::from_utf8(cap.as_bytes())
.ok()
.map(|token| token.to_lowercase())
});
if key.is_none() {
continue;
}
// directive value
// match token or quoted-string
let value = captures
.get(2)
.map(|cap| DirectiveValue(cap.as_bytes().to_vec()));
directives.insert(key.unwrap(), value);
}
}
Some(CacheControl { directives })
}
/// Parse from the given header name in `headers`
pub fn from_headers_named(header_name: &str, headers: &http::HeaderMap) -> Option<Self> {
if !headers.contains_key(header_name) {
return None;
}
Self::from_headers(headers.get_all(header_name))
}
/// Parse from the given header name in the [ReqHeader]
pub fn from_req_headers_named(header_name: &str, req_header: &ReqHeader) -> Option<Self> {
Self::from_headers_named(header_name, &req_header.headers)
}
/// Parse `Cache-Control` header name from the [ReqHeader]
pub fn from_req_headers(req_header: &ReqHeader) -> Option<Self> {
Self::from_req_headers_named("cache-control", req_header)
}
/// Parse from the given header name in the [RespHeader]
pub fn from_resp_headers_named(header_name: &str, resp_header: &RespHeader) -> Option<Self> {
Self::from_headers_named(header_name, &resp_header.headers)
}
/// Parse `Cache-Control` header name from the [RespHeader]
pub fn from_resp_headers(resp_header: &RespHeader) -> Option<Self> {
Self::from_resp_headers_named("cache-control", resp_header)
}
/// Whether the given directive is in the cache control.
pub fn has_key(&self, key: &str) -> bool {
self.directives.contains_key(key)
}
/// Whether the `public` directive is in the cache control.
pub fn public(&self) -> bool {
self.has_key("public")
}
/// Whether the given directive exists and it has no value.
fn has_key_without_value(&self, key: &str) -> bool {
matches!(self.directives.get(key), Some(None))
}
/// Whether the standalone `private` exists in the cache control
// RFC 7234: using the #field-name versions of `private`
// means a shared cache "MUST NOT store the specified field-name(s),
// whereas it MAY store the remainder of the response."
// It must be a boolean form (no value) to apply to the whole response.
// https://datatracker.ietf.org/doc/html/rfc7234#section-5.2.2.6
pub fn private(&self) -> bool {
self.has_key_without_value("private")
}
fn get_field_names(&self, key: &str) -> Option<ListValueIter> {
if let Some(Some(value)) = self.directives.get(key) {
Some(ListValueIter::from(value))
} else {
None
}
}
/// Get the values of `private=`
pub fn private_field_names(&self) -> Option<ListValueIter> {
self.get_field_names("private")
}
/// Whether the standalone `no-cache` exists in the cache control
pub fn no_cache(&self) -> bool {
self.has_key_without_value("no-cache")
}
/// Get the values of `no-cache=`
pub fn no_cache_field_names(&self) -> Option<ListValueIter> {
self.get_field_names("no-cache")
}
/// Whether `no-store` exists.
pub fn no_store(&self) -> bool {
self.has_key("no-store")
}
fn parse_delta_seconds(&self, key: &str) -> Result<Option<u32>> {
if let Some(Some(dir_value)) = self.directives.get(key) {
Ok(Some(dir_value.parse_as_delta_seconds()?))
} else {
Ok(None)
}
}
/// Return the `max-age` seconds
pub fn max_age(&self) -> Result<Option<u32>> {
self.parse_delta_seconds("max-age")
}
/// Return the `s-maxage` seconds
pub fn s_maxage(&self) -> Result<Option<u32>> {
self.parse_delta_seconds("s-maxage")
}
/// Return the `stale-while-revalidate` seconds
pub fn stale_while_revalidate(&self) -> Result<Option<u32>> {
self.parse_delta_seconds("stale-while-revalidate")
}
/// Return the `stale-if-error` seconds
pub fn stale_if_error(&self) -> Result<Option<u32>> {
self.parse_delta_seconds("stale-if-error")
}
/// Whether `must-revalidate` exists.
pub fn must_revalidate(&self) -> bool {
self.has_key("must-revalidate")
}
/// Whether `proxy-revalidate` exists.
pub fn proxy_revalidate(&self) -> bool {
self.has_key("proxy-revalidate")
}
/// Whether `only-if-cached` exists.
pub fn only_if_cached(&self) -> bool {
self.has_key("only-if-cached")
}
}
impl InterpretCacheControl for CacheControl {
fn is_cacheable(&self) -> Cacheable {
if self.no_store() || self.private() {
return Cacheable::No;
}
if self.has_key("s-maxage") || self.has_key("max-age") || self.public() {
return Cacheable::Yes;
}
Cacheable::Default
}
fn allow_caching_authorized_req(&self) -> bool {
// RFC 7234 https://datatracker.ietf.org/doc/html/rfc7234#section-3
// "MUST NOT" store requests with Authorization header
// unless response contains one of these directives
self.must_revalidate() || self.public() || self.has_key("s-maxage")
}
fn fresh_sec(&self) -> Option<u32> {
if self.no_cache() {
// always treated as stale
return Some(0);
}
match self.s_maxage() {
Ok(Some(seconds)) => Some(seconds),
// s-maxage not present
Ok(None) => match self.max_age() {
Ok(Some(seconds)) => Some(seconds),
_ => None,
},
_ => None,
}
}
fn serve_stale_while_revalidate_sec(&self) -> Option<u32> {
// RFC 7234: these directives forbid serving stale.
// https://datatracker.ietf.org/doc/html/rfc7234#section-4.2.4
if self.must_revalidate() || self.proxy_revalidate() || self.has_key("s-maxage") {
return Some(0);
}
self.stale_while_revalidate().unwrap_or(None)
}
fn serve_stale_if_error_sec(&self) -> Option<u32> {
if self.must_revalidate() || self.proxy_revalidate() || self.has_key("s-maxage") {
return Some(0);
}
self.stale_if_error().unwrap_or(None)
}
// Strip header names listed in `private` or `no-cache` directives from a response.
fn strip_private_headers(&self, resp_header: &mut ResponseHeader) {
fn strip_listed_headers(resp: &mut ResponseHeader, field_names: ListValueIter) {
for name in field_names {
if let Ok(header) = HeaderName::from_bytes(name) {
resp.remove_header(&header);
}
}
}
if let Some(headers) = self.private_field_names() {
strip_listed_headers(resp_header, headers);
}
// We interpret `no-cache` the same way as `private`,
// though technically it has a less restrictive requirement
// ("MUST NOT be sent in the response to a subsequent request
// without successful revalidation with the origin server").
// https://datatracker.ietf.org/doc/html/rfc7234#section-5.2.2.2
if let Some(headers) = self.no_cache_field_names() {
strip_listed_headers(resp_header, headers);
}
}
}
/// `InterpretCacheControl` provides a meaningful interface to the parsed `CacheControl`.
/// These functions actually interpret the parsed cache-control directives to return
/// the freshness or other cache meta values that cache-control is signaling.
///
/// By default `CacheControl` implements an RFC-7234 compliant reading that assumes it is being
/// used with a shared (proxy) cache.
pub trait InterpretCacheControl {
/// Does cache-control specify this response is cacheable?
///
/// Note that an RFC-7234 compliant cacheability check must also
/// check if the request contained the Authorization header and
/// `allow_caching_authorized_req`.
fn is_cacheable(&self) -> Cacheable;
/// Does this cache-control allow caching a response to
/// a request with the Authorization header?
fn allow_caching_authorized_req(&self) -> bool;
/// Returns freshness ttl specified in cache-control
///
/// - `Some(_)` indicates cache-control specifies a valid ttl. Some(0) = always stale.
/// - `None` means cache-control did not specify a valid ttl.
fn fresh_sec(&self) -> Option<u32>;
/// Returns the stale-while-revalidate ttl.
///
/// The result should consider all the relevant cache directives, not just SWR header itself.
///
/// Some(0) means serving such stale is disallowed by directive like `must-revalidate`
/// or `stale-while-revalidate=0`.
///
/// `None` indicates no SWR ttl was specified.
fn serve_stale_while_revalidate_sec(&self) -> Option<u32>;
/// Returns the stale-if-error ttl.
///
/// The result should consider all the relevant cache directives, not just SIE header itself.
///
/// Some(0) means serving such stale is disallowed by directive like `must-revalidate`
/// or `stale-if-error=0`.
///
/// `None` indicates no SIE ttl was specified.
fn serve_stale_if_error_sec(&self) -> Option<u32>;
/// Strip header names listed in `private` or `no-cache` directives from a response,
/// usually prior to storing that response in cache.
fn strip_private_headers(&self, resp_header: &mut ResponseHeader);
}
#[cfg(test)]
mod tests {
use super::*;
use http::header::CACHE_CONTROL;
use http::HeaderValue;
use http::{request, response};
fn build_response(cc_key: HeaderName, cc_value: &str) -> response::Parts {
let (parts, _) = response::Builder::new()
.header(cc_key, cc_value)
.body(())
.unwrap()
.into_parts();
parts
}
#[test]
fn test_simple_cache_control() {
let resp = build_response(CACHE_CONTROL, "public, max-age=10000");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.public());
assert_eq!(cc.max_age().unwrap().unwrap(), 10000);
}
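// An illustrative test (not part of the original suite), sketching the RFC 7234
// delta-seconds overflow behavior documented on parse_as_delta_seconds: a value
// too large for u32 is capped at DELTA_SECONDS_OVERFLOW_VALUE rather than erroring.
#[test]
fn test_delta_seconds_overflow_is_capped() {
    let resp = build_response(CACHE_CONTROL, "max-age=99999999999999999999");
    let cc = CacheControl::from_resp_headers(&resp).unwrap();
    assert_eq!(cc.max_age().unwrap().unwrap(), DELTA_SECONDS_OVERFLOW_VALUE);
}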
#[test]
fn test_private_cache_control() {
let resp = build_response(CACHE_CONTROL, "private");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.private());
assert!(cc.max_age().unwrap().is_none());
}
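// An illustrative test (not part of the original suite; assumes the pingora_http
// ResponseHeader builder API): `private` with a field-name value is not the
// standalone boolean form, but strip_private_headers should still remove the
// listed header from the response.
#[test]
fn test_strip_private_field_names() {
    let mut resp = ResponseHeader::build(200, None).unwrap();
    resp.insert_header("cache-control", "private=\"set-cookie\"")
        .unwrap();
    resp.insert_header("set-cookie", "id=1").unwrap();
    let cc = CacheControl::from_headers_named("cache-control", &resp.headers).unwrap();
    // has a value, so not the whole-response boolean form
    assert!(!cc.private());
    cc.strip_private_headers(&mut resp);
    assert!(resp.headers.get("set-cookie").is_none());
}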
#[test]
fn test_directives_across_header_lines() {
let (parts, _) = response::Builder::new()
.header(CACHE_CONTROL, "public,")
.header("cache-Control", "max-age=10000")
.body(())
.unwrap()
.into_parts();
let cc = CacheControl::from_resp_headers(&parts).unwrap();
assert!(cc.public());
assert_eq!(cc.max_age().unwrap().unwrap(), 10000);
}
#[test]
fn test_recognizes_semicolons_as_delimiters() {
let resp = build_response(CACHE_CONTROL, "public; max-age=0");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.public());
assert_eq!(cc.max_age().unwrap().unwrap(), 0);
}
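// An illustrative test (not part of the original suite): per RFC 7234 section 4.2.4,
// `must-revalidate` forbids serving stale, so serve_stale_while_revalidate_sec
// returns Some(0) even when a stale-while-revalidate ttl is present.
#[test]
fn test_must_revalidate_disables_swr() {
    let resp = build_response(
        CACHE_CONTROL,
        "max-age=10, stale-while-revalidate=30, must-revalidate",
    );
    let cc = CacheControl::from_resp_headers(&resp).unwrap();
    assert_eq!(cc.fresh_sec().unwrap(), 10);
    assert_eq!(cc.serve_stale_while_revalidate_sec().unwrap(), 0);
}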
#[test]
fn test_unknown_directives() {
let resp = build_response(CACHE_CONTROL, "public,random1=random2, rand3=\"\"");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
let mut directive_iter = cc.directives.iter();
let first = directive_iter.next().unwrap();
assert_eq!(first.0, &"public");
assert!(first.1.is_none());
let second = directive_iter.next().unwrap();
assert_eq!(second.0, &"random1");
assert_eq!(second.1.as_ref().unwrap().0, "random2".as_bytes());
let third = directive_iter.next().unwrap();
assert_eq!(third.0, &"rand3");
assert_eq!(third.1.as_ref().unwrap().0, "\"\"".as_bytes());
assert!(directive_iter.next().is_none());
}
#[test]
fn test_case_insensitive_directive_keys() {
let resp = build_response(
CACHE_CONTROL,
"Public=\"something\", mAx-AGe=\"10000\", foo=cRaZyCaSe, bAr=\"inQuotes\"",
);
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.public());
assert_eq!(cc.max_age().unwrap().unwrap(), 10000);
let mut directive_iter = cc.directives.iter();
let first = directive_iter.next().unwrap();
assert_eq!(first.0, &"public");
assert_eq!(first.1.as_ref().unwrap().0, "\"something\"".as_bytes());
let second = directive_iter.next().unwrap();
assert_eq!(second.0, &"max-age");
assert_eq!(second.1.as_ref().unwrap().0, "\"10000\"".as_bytes());
        // values are still stored with their original casing
let third = directive_iter.next().unwrap();
assert_eq!(third.0, &"foo");
assert_eq!(third.1.as_ref().unwrap().0, "cRaZyCaSe".as_bytes());
let fourth = directive_iter.next().unwrap();
assert_eq!(fourth.0, &"bar");
assert_eq!(fourth.1.as_ref().unwrap().0, "\"inQuotes\"".as_bytes());
assert!(directive_iter.next().is_none());
}
#[test]
fn test_non_ascii() {
let resp = build_response(CACHE_CONTROL, "püblic=💖, max-age=\"💯\"");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
// Not considered valid registered directive keys / values
assert!(!cc.public());
assert_eq!(
cc.max_age().unwrap_err().context.unwrap().to_string(),
"could not parse value as u32"
);
let mut directive_iter = cc.directives.iter();
let first = directive_iter.next().unwrap();
assert_eq!(first.0, &"püblic");
assert_eq!(first.1.as_ref().unwrap().0, "💖".as_bytes());
let second = directive_iter.next().unwrap();
assert_eq!(second.0, &"max-age");
assert_eq!(second.1.as_ref().unwrap().0, "\"💯\"".as_bytes());
assert!(directive_iter.next().is_none());
}
#[test]
fn test_non_utf8_key() {
let mut resp = response::Builder::new().body(()).unwrap();
resp.headers_mut().insert(
CACHE_CONTROL,
HeaderValue::from_bytes(b"bar\xFF=\"baz\", a=b").unwrap(),
);
let (parts, _) = resp.into_parts();
let cc = CacheControl::from_resp_headers(&parts).unwrap();
// invalid bytes for key
let mut directive_iter = cc.directives.iter();
let first = directive_iter.next().unwrap();
assert_eq!(first.0, &"a");
assert_eq!(first.1.as_ref().unwrap().0, "b".as_bytes());
assert!(directive_iter.next().is_none());
}
#[test]
fn test_non_utf8_value() {
// RFC 7230: 0xFF is part of obs-text and is officially considered a valid octet in quoted-strings
let mut resp = response::Builder::new().body(()).unwrap();
resp.headers_mut().insert(
CACHE_CONTROL,
HeaderValue::from_bytes(b"max-age=ba\xFFr, bar=\"baz\xFF\", a=b").unwrap(),
);
let (parts, _) = resp.into_parts();
let cc = CacheControl::from_resp_headers(&parts).unwrap();
assert_eq!(
cc.max_age().unwrap_err().context.unwrap().to_string(),
"could not parse value as utf8"
);
let mut directive_iter = cc.directives.iter();
let first = directive_iter.next().unwrap();
assert_eq!(first.0, &"max-age");
assert_eq!(first.1.as_ref().unwrap().0, b"ba\xFFr");
let second = directive_iter.next().unwrap();
assert_eq!(second.0, &"bar");
assert_eq!(second.1.as_ref().unwrap().0, b"\"baz\xFF\"");
let third = directive_iter.next().unwrap();
assert_eq!(third.0, &"a");
assert_eq!(third.1.as_ref().unwrap().0, "b".as_bytes());
assert!(directive_iter.next().is_none());
}
#[test]
fn test_age_overflow() {
let resp = build_response(
CACHE_CONTROL,
"max-age=-99999999999999999999999999, s-maxage=99999999999999999999999999",
);
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(
cc.s_maxage().unwrap().unwrap(),
DELTA_SECONDS_OVERFLOW_VALUE
);
// negative ages still result in errors even with overflow handling
assert_eq!(
cc.max_age().unwrap_err().context.unwrap().to_string(),
"could not parse value as u32"
);
}
#[test]
fn test_fresh_sec() {
let resp = build_response(CACHE_CONTROL, "");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.fresh_sec().is_none());
let resp = build_response(CACHE_CONTROL, "max-age=12345");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.fresh_sec().unwrap(), 12345);
let resp = build_response(CACHE_CONTROL, "max-age=99999,s-maxage=123");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
// prefer s-maxage over max-age
assert_eq!(cc.fresh_sec().unwrap(), 123);
}
#[test]
fn test_cacheability() {
let resp = build_response(CACHE_CONTROL, "");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::Default);
// uncacheable
let resp = build_response(CACHE_CONTROL, "private, max-age=12345");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::No);
let resp = build_response(CACHE_CONTROL, "no-store, max-age=12345");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::No);
// cacheable
let resp = build_response(CACHE_CONTROL, "public");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::Yes);
let resp = build_response(CACHE_CONTROL, "max-age=0");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::Yes);
}
#[test]
fn test_no_cache() {
let resp = build_response(CACHE_CONTROL, "no-cache, max-age=12345");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.is_cacheable(), Cacheable::Yes);
assert_eq!(cc.fresh_sec().unwrap(), 0);
}
#[test]
fn test_no_cache_field_names() {
let resp = build_response(CACHE_CONTROL, "no-cache=\"set-cookie\", max-age=12345");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(!cc.private());
assert_eq!(cc.is_cacheable(), Cacheable::Yes);
assert_eq!(cc.fresh_sec().unwrap(), 12345);
let mut field_names = cc.no_cache_field_names().unwrap();
assert_eq!(
str::from_utf8(field_names.next().unwrap()).unwrap(),
"set-cookie"
);
assert!(field_names.next().is_none());
let mut resp = response::Builder::new().body(()).unwrap();
resp.headers_mut().insert(
CACHE_CONTROL,
HeaderValue::from_bytes(
b"private=\"\", no-cache=\"a\xFF, set-cookie, Baz\x09 , c,d ,, \"",
)
.unwrap(),
);
let (parts, _) = resp.into_parts();
let cc = CacheControl::from_resp_headers(&parts).unwrap();
let mut field_names = cc.private_field_names().unwrap();
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "");
assert!(field_names.next().is_none());
let mut field_names = cc.no_cache_field_names().unwrap();
assert!(str::from_utf8(field_names.next().unwrap()).is_err());
assert_eq!(
str::from_utf8(field_names.next().unwrap()).unwrap(),
"set-cookie"
);
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "Baz");
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "c");
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "d");
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "");
assert_eq!(str::from_utf8(field_names.next().unwrap()).unwrap(), "");
assert!(field_names.next().is_none());
}
#[test]
fn test_strip_private_headers() {
let mut resp = ResponseHeader::build(200, None).unwrap();
resp.append_header(
CACHE_CONTROL,
"no-cache=\"x-private-header\", max-age=12345",
)
.unwrap();
resp.append_header("X-Private-Header", "dropped").unwrap();
let cc = CacheControl::from_resp_headers(&resp).unwrap();
cc.strip_private_headers(&mut resp);
assert!(!resp.headers.contains_key("X-Private-Header"));
}
#[test]
fn test_stale_while_revalidate() {
let resp = build_response(CACHE_CONTROL, "max-age=12345, stale-while-revalidate=5");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.stale_while_revalidate().unwrap().unwrap(), 5);
assert_eq!(cc.serve_stale_while_revalidate_sec().unwrap(), 5);
assert!(cc.serve_stale_if_error_sec().is_none());
}
#[test]
fn test_stale_if_error() {
let resp = build_response(CACHE_CONTROL, "max-age=12345, stale-if-error=3600");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.stale_if_error().unwrap().unwrap(), 3600);
assert_eq!(cc.serve_stale_if_error_sec().unwrap(), 3600);
assert!(cc.serve_stale_while_revalidate_sec().is_none());
}
#[test]
fn test_must_revalidate() {
let resp = build_response(
CACHE_CONTROL,
"max-age=12345, stale-while-revalidate=60, stale-if-error=30, must-revalidate",
);
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.must_revalidate());
assert_eq!(cc.stale_while_revalidate().unwrap().unwrap(), 60);
assert_eq!(cc.stale_if_error().unwrap().unwrap(), 30);
assert_eq!(cc.serve_stale_while_revalidate_sec().unwrap(), 0);
assert_eq!(cc.serve_stale_if_error_sec().unwrap(), 0);
}
#[test]
fn test_proxy_revalidate() {
let resp = build_response(
CACHE_CONTROL,
"max-age=12345, stale-while-revalidate=60, stale-if-error=30, proxy-revalidate",
);
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.proxy_revalidate());
assert_eq!(cc.stale_while_revalidate().unwrap().unwrap(), 60);
assert_eq!(cc.stale_if_error().unwrap().unwrap(), 30);
assert_eq!(cc.serve_stale_while_revalidate_sec().unwrap(), 0);
assert_eq!(cc.serve_stale_if_error_sec().unwrap(), 0);
}
#[test]
fn test_s_maxage_stale() {
let resp = build_response(
CACHE_CONTROL,
"s-maxage=0, stale-while-revalidate=60, stale-if-error=30",
);
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert_eq!(cc.stale_while_revalidate().unwrap().unwrap(), 60);
assert_eq!(cc.stale_if_error().unwrap().unwrap(), 30);
assert_eq!(cc.serve_stale_while_revalidate_sec().unwrap(), 0);
assert_eq!(cc.serve_stale_if_error_sec().unwrap(), 0);
}
#[test]
fn test_authorized_request() {
let resp = build_response(CACHE_CONTROL, "max-age=10");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(!cc.allow_caching_authorized_req());
let resp = build_response(CACHE_CONTROL, "s-maxage=10");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.allow_caching_authorized_req());
let resp = build_response(CACHE_CONTROL, "public");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.allow_caching_authorized_req());
let resp = build_response(CACHE_CONTROL, "must-revalidate, max-age=0");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(cc.allow_caching_authorized_req());
let resp = build_response(CACHE_CONTROL, "");
let cc = CacheControl::from_resp_headers(&resp).unwrap();
assert!(!cc.allow_caching_authorized_req());
}
fn build_request(cc_key: HeaderName, cc_value: &str) -> request::Parts {
let (parts, _) = request::Builder::new()
.header(cc_key, cc_value)
.body(())
.unwrap()
.into_parts();
parts
}
#[test]
fn test_request_only_if_cached() {
let req = build_request(CACHE_CONTROL, "only-if-cached=1");
let cc = CacheControl::from_req_headers(&req).unwrap();
assert!(cc.only_if_cached())
}
}
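The tests above exercise directive splitting across commas and header lines, case-insensitive keys, and values that keep their original casing. A minimal, self-contained sketch of that splitting behavior (a hypothetical `split_directives` helper, not Pingora's actual parser, which also handles quoted strings and non-UTF-8 bytes) could look like:

```rust
// Sketch only: split a Cache-Control header into (key, value) directives.
// Keys are lowercased for case-insensitive lookup; values keep their casing.
fn split_directives(header: &str) -> Vec<(String, Option<String>)> {
    header
        .split(',')
        .map(str::trim)
        .filter(|d| !d.is_empty())
        .map(|d| match d.split_once('=') {
            Some((k, v)) => (k.trim().to_ascii_lowercase(), Some(v.trim().to_string())),
            None => (d.to_ascii_lowercase(), None),
        })
        .collect()
}

fn main() {
    let dirs = split_directives("Public, mAx-AGe=10000, foo=cRaZyCaSe");
    assert_eq!(dirs[0], ("public".to_string(), None));
    assert_eq!(dirs[1], ("max-age".to_string(), Some("10000".to_string())));
    // values retain their original casing
    assert_eq!(dirs[2].1.as_deref(), Some("cRaZyCaSe"));
    println!("{dirs:?}");
}
```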


@@ -0,0 +1,431 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A shared LRU cache manager
use super::EvictionManager;
use crate::key::CompactCacheKey;
use async_trait::async_trait;
use pingora_error::{BError, ErrorType::*, OrErr, Result};
use pingora_lru::Lru;
use serde::de::SeqAccess;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::hash::{Hash, Hasher};
use std::io::prelude::*;
use std::path::Path;
use std::time::SystemTime;
/// A shared LRU cache manager designed to manage a large volume of assets.
///
/// - Space optimized in-memory LRU (see [pingora_lru]).
/// - Instead of a single giant LRU, this struct shards the assets into `N` independent LRUs.
///   This allows [EvictionManager::save()] to avoid locking the entire cache manager while
///   performing serialization.
pub struct Manager<const N: usize>(Lru<CompactCacheKey, N>);
#[derive(Debug, Serialize, Deserialize)]
struct SerdeHelperNode(CompactCacheKey, usize);
impl<const N: usize> Manager<N> {
/// Create a [Manager] with the given size limit and estimated per shard capacity.
///
/// The `capacity` is for preallocating to avoid reallocation cost when the LRU grows.
pub fn with_capacity(limit: usize, capacity: usize) -> Self {
Manager(Lru::with_capacity(limit, capacity))
}
/// Serialize the given shard
pub fn serialize_shard(&self, shard: usize) -> Result<Vec<u8>> {
use rmp_serde::encode::Serializer;
use serde::ser::SerializeSeq;
use serde::ser::Serializer as _;
assert!(shard < N);
        // NOTE: This could use a lot of memory to buffer the serialized data
// NOTE: This for loop could lock the LRU for too long
let mut nodes = Vec::with_capacity(self.0.shard_len(shard));
self.0.iter_for_each(shard, |(node, size)| {
nodes.push(SerdeHelperNode(node.clone(), size));
});
let mut ser = Serializer::new(vec![]);
let mut seq = ser
.serialize_seq(Some(self.0.shard_len(shard)))
.or_err(InternalError, "fail to serialize node")?;
for node in nodes {
seq.serialize_element(&node).unwrap(); // write to vec, safe
}
seq.end().or_err(InternalError, "when serializing LRU")?;
Ok(ser.into_inner())
}
/// Deserialize a shard
///
/// Shard number is not needed because the key itself will hash to the correct shard.
pub fn deserialize_shard(&self, buf: &[u8]) -> Result<()> {
use rmp_serde::decode::Deserializer;
use serde::de::Deserializer as _;
let mut de = Deserializer::new(buf);
let visitor = InsertToManager { lru: self };
de.deserialize_seq(visitor)
.or_err(InternalError, "when deserializing LRU")?;
Ok(())
}
}
struct InsertToManager<'a, const N: usize> {
lru: &'a Manager<N>,
}
impl<'de, 'a, const N: usize> serde::de::Visitor<'de> for InsertToManager<'a, N> {
type Value = ();
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("array of lru nodes")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
while let Some(node) = seq.next_element::<SerdeHelperNode>()? {
let key = u64key(&node.0);
self.lru.0.insert_tail(key, node.0, node.1); // insert in the back
}
Ok(())
}
}
#[inline]
fn u64key(key: &CompactCacheKey) -> u64 {
    // note that std's default hasher is not guaranteed to be uniform; it is unclear whether
    // the same applies to ahash
let mut hasher = ahash::AHasher::default();
key.hash(&mut hasher);
hasher.finish()
}
const FILE_NAME: &str = "lru.data";
#[inline]
fn err_str_path(s: &str, path: &Path) -> String {
format!("{s} {}", path.display())
}
#[async_trait]
impl<const N: usize> EvictionManager for Manager<N> {
fn total_size(&self) -> usize {
self.0.weight()
}
fn total_items(&self) -> usize {
self.0.len()
}
fn evicted_size(&self) -> usize {
self.0.evicted_weight()
}
fn evicted_items(&self) -> usize {
self.0.evicted_len()
}
fn admit(
&self,
item: CompactCacheKey,
size: usize,
_fresh_until: SystemTime,
) -> Vec<CompactCacheKey> {
let key = u64key(&item);
self.0.admit(key, item, size);
self.0
.evict_to_limit()
.into_iter()
.map(|(key, _weight)| key)
.collect()
}
fn remove(&self, item: &CompactCacheKey) {
let key = u64key(item);
self.0.remove(key);
}
fn access(&self, item: &CompactCacheKey, size: usize, _fresh_until: SystemTime) -> bool {
let key = u64key(item);
if !self.0.promote(key) {
self.0.admit(key, item.clone(), size);
false
} else {
true
}
}
fn peek(&self, item: &CompactCacheKey) -> bool {
let key = u64key(item);
self.0.peek(key)
}
async fn save(&self, dir_path: &str) -> Result<()> {
let dir_path_str = dir_path.to_owned();
tokio::task::spawn_blocking(move || {
let dir_path = Path::new(&dir_path_str);
std::fs::create_dir_all(dir_path)
.or_err_with(InternalError, || err_str_path("fail to create", dir_path))
})
.await
.or_err(InternalError, "async blocking IO failure")??;
for i in 0..N {
let data = self.serialize_shard(i)?;
let dir_path = dir_path.to_owned();
tokio::task::spawn_blocking(move || {
let file_path = Path::new(&dir_path).join(format!("{}.{i}", FILE_NAME));
let mut file = File::create(&file_path)
.or_err_with(InternalError, || err_str_path("fail to create", &file_path))?;
file.write_all(&data).or_err_with(InternalError, || {
err_str_path("fail to write to", &file_path)
})
})
.await
.or_err(InternalError, "async blocking IO failure")??;
}
Ok(())
}
async fn load(&self, dir_path: &str) -> Result<()> {
// TODO: check the saved shards so that we load all the save files
for i in 0..N {
let dir_path = dir_path.to_owned();
let data = tokio::task::spawn_blocking(move || {
let file_path = Path::new(&dir_path).join(format!("{}.{i}", FILE_NAME));
let mut file = File::open(&file_path)
.or_err_with(InternalError, || err_str_path("fail to open", &file_path))?;
let mut buffer = Vec::with_capacity(8192);
file.read_to_end(&mut buffer)
.or_err_with(InternalError, || {
                        err_str_path("fail to read", &file_path)
})?;
Ok::<Vec<u8>, BError>(buffer)
})
.await
.or_err(InternalError, "async blocking IO failure")??;
self.deserialize_shard(&data)?;
}
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::CacheKey;
use EvictionManager;
// we use shard (N) = 1 for eviction consistency in all tests
#[test]
fn test_admission() {
let lru = Manager::<1>::with_capacity(4, 10);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
        // lru is full (4) now
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru.admit(key4, 2, until);
        // need to free at least 2 to fit key4; evicting both key1 and key2 frees 3
assert_eq!(v.len(), 2);
assert_eq!(v[0], key1);
assert_eq!(v[1], key2);
}
#[test]
fn test_access() {
let lru = Manager::<1>::with_capacity(4, 10);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
// lru is full (4) now
// make key1 most recently used
lru.access(&key1, 1, until);
assert_eq!(v.len(), 0);
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru.admit(key4, 2, until);
assert_eq!(v.len(), 1);
assert_eq!(v[0], key2);
}
#[test]
fn test_remove() {
let lru = Manager::<1>::with_capacity(4, 10);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
// lru is full (4) now
// remove key1
lru.remove(&key1);
// key2 is the least recently used one now
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru.admit(key4, 2, until);
assert_eq!(v.len(), 1);
assert_eq!(v[0], key2);
}
#[test]
fn test_access_add() {
let lru = Manager::<1>::with_capacity(4, 10);
let until = SystemTime::now(); // unused value as a placeholder
let key1 = CacheKey::new("", "a", "1").to_compact();
lru.access(&key1, 1, until);
let key2 = CacheKey::new("", "b", "1").to_compact();
lru.access(&key2, 2, until);
let key3 = CacheKey::new("", "c", "1").to_compact();
lru.access(&key3, 2, until);
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru.admit(key4, 2, until);
        // need to free at least 2 to fit key4; evicting both key1 and key2 frees 3
assert_eq!(v.len(), 2);
assert_eq!(v[0], key1);
assert_eq!(v[1], key2);
}
#[test]
fn test_admit_update() {
let lru = Manager::<1>::with_capacity(4, 10);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
// lru is full (4) now
// update key2 to reduce its size by 1
let v = lru.admit(key2, 1, until);
assert_eq!(v.len(), 0);
// lru is not full anymore
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru.admit(key4.clone(), 1, until);
assert_eq!(v.len(), 0);
// make key4 larger
let v = lru.admit(key4, 2, until);
// need to evict now
assert_eq!(v.len(), 1);
assert_eq!(v[0], key1);
}
#[test]
fn test_peek() {
let lru = Manager::<1>::with_capacity(4, 10);
let until = SystemTime::now(); // unused value as a placeholder
let key1 = CacheKey::new("", "a", "1").to_compact();
lru.access(&key1, 1, until);
let key2 = CacheKey::new("", "b", "1").to_compact();
lru.access(&key2, 2, until);
assert!(lru.peek(&key1));
assert!(lru.peek(&key2));
}
#[test]
fn test_serde() {
let lru = Manager::<1>::with_capacity(4, 10);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
// lru is full (4) now
// make key1 most recently used
lru.access(&key1, 1, until);
assert_eq!(v.len(), 0);
// load lru2 with lru's data
let ser = lru.serialize_shard(0).unwrap();
let lru2 = Manager::<1>::with_capacity(4, 10);
lru2.deserialize_shard(&ser).unwrap();
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru2.admit(key4, 2, until);
assert_eq!(v.len(), 1);
assert_eq!(v[0], key2);
}
#[tokio::test]
async fn test_save_to_disk() {
let until = SystemTime::now(); // unused value as a placeholder
let lru = Manager::<2>::with_capacity(10, 10);
lru.admit(CacheKey::new("", "a", "1").to_compact(), 1, until);
lru.admit(CacheKey::new("", "b", "1").to_compact(), 2, until);
lru.admit(CacheKey::new("", "c", "1").to_compact(), 1, until);
lru.admit(CacheKey::new("", "d", "1").to_compact(), 1, until);
lru.admit(CacheKey::new("", "e", "1").to_compact(), 2, until);
lru.admit(CacheKey::new("", "f", "1").to_compact(), 1, until);
        // save lru's data to disk, then load it into lru2
lru.save("/tmp/test_lru_save").await.unwrap();
let lru2 = Manager::<2>::with_capacity(4, 10);
lru2.load("/tmp/test_lru_save").await.unwrap();
let ser0 = lru.serialize_shard(0).unwrap();
let ser1 = lru.serialize_shard(1).unwrap();
assert_eq!(ser0, lru2.serialize_shard(0).unwrap());
assert_eq!(ser1, lru2.serialize_shard(1).unwrap());
}
}
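The sharded manager above relies on the fact that a key's hash alone determines which shard it belongs to, which is why `deserialize_shard` needs no shard number. A self-contained sketch of that mapping (using std's `DefaultHasher` as a stand-in for the ahash hasher Pingora uses) is:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Sketch: a key hashes to a stable u64, and the shard index is derived from
// that hash, so re-inserting a key always lands it in the same shard.
fn shard_of<K: Hash, const N: usize>(key: &K) -> usize {
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    (hasher.finish() % N as u64) as usize
}

fn main() {
    const N: usize = 8;
    let shard = shard_of::<_, N>(&"some-cache-key");
    assert!(shard < N);
    // deterministic: the same key always maps to the same shard
    assert_eq!(shard, shard_of::<_, N>(&"some-cache-key"));
}
```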


@@ -0,0 +1,89 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Cache eviction module
use crate::key::CompactCacheKey;
use async_trait::async_trait;
use pingora_error::Result;
use std::time::SystemTime;
pub mod lru;
pub mod simple_lru;
/// The trait that a cache eviction algorithm needs to implement
///
/// NOTE: these trait methods take `&self`, not `&mut self`, which means implementations must
/// handle concurrency internally.
#[async_trait]
pub trait EvictionManager {
    /// Total size of the cache in bytes tracked by this eviction manager
fn total_size(&self) -> usize;
    /// Number of assets tracked by this eviction manager
fn total_items(&self) -> usize;
/// Number of bytes that are already evicted
///
/// The accumulated number is returned to play well with Prometheus counter metric type.
fn evicted_size(&self) -> usize;
/// Number of assets that are already evicted
///
/// The accumulated number is returned to play well with Prometheus counter metric type.
fn evicted_items(&self) -> usize;
/// Admit an item
///
/// Return one or more items to evict. The sizes of these items are deducted
/// from the total size already. The caller needs to make sure that these assets are actually
/// removed from the storage.
///
    /// If the item is already admitted, its freshness is updated; and if its new size is larger
    /// than the existing one, eviction candidates may be returned for the caller to remove.
fn admit(
&self,
item: CompactCacheKey,
size: usize,
fresh_until: SystemTime,
) -> Vec<CompactCacheKey>;
/// Remove an item from the eviction manager.
///
/// The size of the item will be deducted.
fn remove(&self, item: &CompactCacheKey);
/// Access an item that should already be in cache.
///
    /// If the item is not yet tracked by this [EvictionManager], start tracking it, but no
    /// eviction will happen.
///
    /// This call is used to ask the eviction manager to track assets that are already admitted
    /// in the cache storage system.
fn access(&self, item: &CompactCacheKey, size: usize, fresh_until: SystemTime) -> bool;
/// Peek into the manager to see if the item is already tracked by the system
///
/// This function should have no side-effect on the asset itself. For example, for LRU, this
/// method shouldn't change the popularity of the asset being peeked.
fn peek(&self, item: &CompactCacheKey) -> bool;
    /// Serialize to save the state of this eviction manager to disk
///
/// This function is for preserving the eviction manager's state across server restarts.
///
    /// `dir_path` defines the directory on disk where the data should be stored.
    // dir_path is &str, not AsRef<Path>, so that trait objects can be used
async fn save(&self, dir_path: &str) -> Result<()>;
/// The counterpart of [Self::save()].
async fn load(&self, dir_path: &str) -> Result<()>;
}
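The `admit` contract above — admitting may push the tracked total over the limit, and the evicted keys are returned for the caller to actually remove from storage — can be illustrated with a toy, single-threaded sketch (illustration only; `ToyEviction` is hypothetical and not Pingora's implementation):

```rust
use std::collections::VecDeque;

// Toy model of the `admit` contract: track (key, size) in admission order,
// and when total size exceeds `limit`, pop the oldest items and return their
// keys so the caller can delete them from storage.
struct ToyEviction {
    limit: usize,
    used: usize,
    queue: VecDeque<(String, usize)>, // front = oldest
}

impl ToyEviction {
    fn admit(&mut self, key: String, size: usize) -> Vec<String> {
        self.queue.push_back((key, size));
        self.used += size;
        let mut evicted = vec![];
        while self.used > self.limit {
            let (k, s) = self.queue.pop_front().expect("used > 0 implies non-empty");
            self.used -= s;
            evicted.push(k);
        }
        evicted
    }
}

fn main() {
    let mut ev = ToyEviction { limit: 4, used: 0, queue: VecDeque::new() };
    assert!(ev.admit("a".into(), 1).is_empty());
    assert!(ev.admit("b".into(), 2).is_empty());
    assert!(ev.admit("c".into(), 1).is_empty()); // full at 4
    // admitting "d" (size 2) must free at least 2 bytes: "a" and "b" go
    assert_eq!(ev.admit("d".into(), 2), vec!["a".to_string(), "b".to_string()]);
}
```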


@@ -0,0 +1,445 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A simple LRU cache manager built on top of the `lru` crate
use super::EvictionManager;
use crate::key::CompactCacheKey;
use async_trait::async_trait;
use lru::LruCache;
use parking_lot::RwLock;
use pingora_error::{BError, ErrorType::*, OrErr, Result};
use serde::de::SeqAccess;
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::fs::File;
use std::hash::{Hash, Hasher};
use std::io::prelude::*;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::SystemTime;
#[derive(Debug, Deserialize, Serialize)]
struct Node {
key: CompactCacheKey,
size: usize,
}
/// A simple LRU eviction manager
///
/// The implementation is not optimized. All operations require a global lock.
pub struct Manager {
lru: RwLock<LruCache<u64, Node>>,
limit: usize,
used: AtomicUsize,
items: AtomicUsize,
evicted_size: AtomicUsize,
evicted_items: AtomicUsize,
}
impl Manager {
/// Create a new [Manager] with the given total size limit `limit`.
pub fn new(limit: usize) -> Self {
Manager {
lru: RwLock::new(LruCache::unbounded()),
limit,
used: AtomicUsize::new(0),
items: AtomicUsize::new(0),
evicted_size: AtomicUsize::new(0),
evicted_items: AtomicUsize::new(0),
}
}
fn insert(&self, hash_key: u64, node: CompactCacheKey, size: usize, reverse: bool) {
use std::cmp::Ordering::*;
let node = Node { key: node, size };
let old = {
let mut lru = self.lru.write();
let old = lru.push(hash_key, node);
if reverse && old.is_none() {
lru.demote(&hash_key);
}
old
};
if let Some(old) = old {
// replacing a node, just need to update used size
match size.cmp(&old.1.size) {
Greater => self.used.fetch_add(size - old.1.size, Ordering::Relaxed),
Less => self.used.fetch_sub(old.1.size - size, Ordering::Relaxed),
Equal => 0, // same size, update nothing, use 0 to match other arms' type
};
} else {
self.used.fetch_add(size, Ordering::Relaxed);
self.items.fetch_add(1, Ordering::Relaxed);
}
}
// evict items until the used capacity is below limit
fn evict(&self) -> Vec<CompactCacheKey> {
if self.used.load(Ordering::Relaxed) <= self.limit {
return vec![];
}
let mut to_evict = Vec::with_capacity(1); // we will at least pop 1 item
while self.used.load(Ordering::Relaxed) > self.limit {
if let Some((_, node)) = self.lru.write().pop_lru() {
self.used.fetch_sub(node.size, Ordering::Relaxed);
self.items.fetch_sub(1, Ordering::Relaxed);
self.evicted_size.fetch_add(node.size, Ordering::Relaxed);
self.evicted_items.fetch_add(1, Ordering::Relaxed);
to_evict.push(node.key);
} else {
// lru empty
return to_evict;
}
}
to_evict
}
    // This could use a lot of memory to buffer the serialized data and could lock the LRU
    // for too long
fn serialize(&self) -> Result<Vec<u8>> {
use rmp_serde::encode::Serializer;
use serde::ser::SerializeSeq;
use serde::ser::Serializer as _;
        // NOTE: This could use a lot of memory to buffer the serialized data
let mut ser = Serializer::new(vec![]);
// NOTE: This long for loop could lock the LRU for too long
let lru = self.lru.read();
let mut seq = ser
.serialize_seq(Some(lru.len()))
.or_err(InternalError, "fail to serialize node")?;
for item in lru.iter() {
seq.serialize_element(item.1).unwrap(); // write to vec, safe
}
seq.end().or_err(InternalError, "when serializing LRU")?;
Ok(ser.into_inner())
}
fn deserialize(&self, buf: &[u8]) -> Result<()> {
use rmp_serde::decode::Deserializer;
use serde::de::Deserializer as _;
let mut de = Deserializer::new(buf);
let visitor = InsertToManager { lru: self };
de.deserialize_seq(visitor)
.or_err(InternalError, "when deserializing LRU")?;
Ok(())
}
}
struct InsertToManager<'a> {
lru: &'a Manager,
}
impl<'de, 'a> serde::de::Visitor<'de> for InsertToManager<'a> {
type Value = ();
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("array of lru nodes")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
while let Some(node) = seq.next_element::<Node>()? {
let key = u64key(&node.key);
self.lru.insert(key, node.key, node.size, true); // insert in the back
}
Ok(())
}
}
#[inline]
fn u64key(key: &CompactCacheKey) -> u64 {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
}
const FILE_NAME: &str = "simple_lru.data";
#[async_trait]
impl EvictionManager for Manager {
fn total_size(&self) -> usize {
self.used.load(Ordering::Relaxed)
}
fn total_items(&self) -> usize {
self.items.load(Ordering::Relaxed)
}
fn evicted_size(&self) -> usize {
self.evicted_size.load(Ordering::Relaxed)
}
fn evicted_items(&self) -> usize {
self.evicted_items.load(Ordering::Relaxed)
}
fn admit(
&self,
item: CompactCacheKey,
size: usize,
_fresh_until: SystemTime,
) -> Vec<CompactCacheKey> {
let key = u64key(&item);
self.insert(key, item, size, false);
self.evict()
}
fn remove(&self, item: &CompactCacheKey) {
let key = u64key(item);
let node = self.lru.write().pop(&key);
if let Some(n) = node {
self.used.fetch_sub(n.size, Ordering::Relaxed);
self.items.fetch_sub(1, Ordering::Relaxed);
}
}
fn access(&self, item: &CompactCacheKey, size: usize, _fresh_until: SystemTime) -> bool {
let key = u64key(item);
if self.lru.write().get(&key).is_none() {
self.insert(key, item.clone(), size, false);
false
} else {
true
}
}
fn peek(&self, item: &CompactCacheKey) -> bool {
let key = u64key(item);
self.lru.read().peek(&key).is_some()
}
async fn save(&self, dir_path: &str) -> Result<()> {
let data = self.serialize()?;
let dir_path = dir_path.to_owned();
tokio::task::spawn_blocking(move || {
let dir_path = Path::new(&dir_path);
std::fs::create_dir_all(dir_path).or_err(InternalError, "fail to create {dir_path}")?;
let file_path = dir_path.join(FILE_NAME);
let mut file =
File::create(file_path).or_err(InternalError, "fail to create {file_path}")?;
file.write_all(&data)
.or_err(InternalError, "fail to write to {file_path}")
})
.await
.or_err(InternalError, "async blocking IO failure")?
}
async fn load(&self, dir_path: &str) -> Result<()> {
let dir_path = dir_path.to_owned();
let data = tokio::task::spawn_blocking(move || {
let file_path = Path::new(&dir_path).join(FILE_NAME);
let mut file =
File::open(file_path).or_err(InternalError, "fail to open {file_path}")?;
let mut buffer = Vec::with_capacity(8192);
file.read_to_end(&mut buffer)
.or_err(InternalError, "fail to read {file_path}")?;
Ok::<Vec<u8>, BError>(buffer)
})
.await
.or_err(InternalError, "async blocking IO failure")??;
self.deserialize(&data)
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::CacheKey;
#[test]
fn test_admission() {
let lru = Manager::new(4);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
// lru is full (4) now
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru.admit(key4, 2, until);
// need to free at least 2; both key1 and key2 are evicted, freeing 3
assert_eq!(v.len(), 2);
assert_eq!(v[0], key1);
assert_eq!(v[1], key2);
}
#[test]
fn test_access() {
let lru = Manager::new(4);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
// lru is full (4) now
// make key1 most recently used
lru.access(&key1, 1, until);
assert_eq!(v.len(), 0);
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru.admit(key4, 2, until);
assert_eq!(v.len(), 1);
assert_eq!(v[0], key2);
}
#[test]
fn test_remove() {
let lru = Manager::new(4);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
// lru is full (4) now
// remove key1
lru.remove(&key1);
// key2 is the least recently used one now
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru.admit(key4, 2, until);
assert_eq!(v.len(), 1);
assert_eq!(v[0], key2);
}
#[test]
fn test_access_add() {
let lru = Manager::new(4);
let until = SystemTime::now(); // unused value as a placeholder
let key1 = CacheKey::new("", "a", "1").to_compact();
lru.access(&key1, 1, until);
let key2 = CacheKey::new("", "b", "1").to_compact();
lru.access(&key2, 2, until);
let key3 = CacheKey::new("", "c", "1").to_compact();
lru.access(&key3, 2, until);
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru.admit(key4, 2, until);
// need to free at least 2; both key1 and key2 are evicted, freeing 3
assert_eq!(v.len(), 2);
assert_eq!(v[0], key1);
assert_eq!(v[1], key2);
}
#[test]
fn test_admit_update() {
let lru = Manager::new(4);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
// lru is full (4) now
// update key2 to reduce its size by 1
let v = lru.admit(key2, 1, until);
assert_eq!(v.len(), 0);
// lru is not full anymore
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru.admit(key4.clone(), 1, until);
assert_eq!(v.len(), 0);
// make key4 larger
let v = lru.admit(key4, 2, until);
// need to evict now
assert_eq!(v.len(), 1);
assert_eq!(v[0], key1);
}
#[test]
fn test_serde() {
let lru = Manager::new(4);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
// lru is full (4) now
// make key1 most recently used
lru.access(&key1, 1, until);
assert_eq!(v.len(), 0);
// load lru2 with lru's data
let ser = lru.serialize().unwrap();
let lru2 = Manager::new(4);
lru2.deserialize(&ser).unwrap();
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru2.admit(key4, 2, until);
assert_eq!(v.len(), 1);
assert_eq!(v[0], key2);
}
#[tokio::test]
async fn test_save_to_disk() {
let lru = Manager::new(4);
let key1 = CacheKey::new("", "a", "1").to_compact();
let until = SystemTime::now(); // unused value as a placeholder
let v = lru.admit(key1.clone(), 1, until);
assert_eq!(v.len(), 0);
let key2 = CacheKey::new("", "b", "1").to_compact();
let v = lru.admit(key2.clone(), 2, until);
assert_eq!(v.len(), 0);
let key3 = CacheKey::new("", "c", "1").to_compact();
let v = lru.admit(key3, 1, until);
assert_eq!(v.len(), 0);
// lru is full (4) now
// make key1 most recently used
lru.access(&key1, 1, until);
assert_eq!(v.len(), 0);
// load lru2 with lru's data
lru.save("/tmp/test_simple_lru_save").await.unwrap();
let lru2 = Manager::new(4);
lru2.load("/tmp/test_simple_lru_save").await.unwrap();
let key4 = CacheKey::new("", "d", "1").to_compact();
let v = lru2.admit(key4, 2, until);
assert_eq!(v.len(), 1);
assert_eq!(v[0], key2);
}
}

// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Utility functions to help process HTTP headers for caching
use super::*;
use crate::cache_control::{CacheControl, Cacheable, InterpretCacheControl};
use crate::{RespCacheable, RespCacheable::*};
use http::{header, HeaderValue};
use httpdate::HttpDate;
use log::warn;
use pingora_http::{RequestHeader, ResponseHeader};
/// Decide if the request can be cacheable
pub fn request_cacheable(req_header: &ReqHeader) -> bool {
// TODO: the check is incomplete
matches!(req_header.method, Method::GET | Method::HEAD)
}
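A minimal sketch of the method check above, using a hypothetical stand-in enum (the real code matches on `http::Method`):

```rust
// Hypothetical stand-in for http::Method, only to illustrate the matches! check.
#[derive(Debug, PartialEq)]
enum Method {
    Get,
    Head,
    Post,
}

// Mirrors request_cacheable above: only GET and HEAD are considered cacheable.
fn request_cacheable(method: &Method) -> bool {
    matches!(method, Method::Get | Method::Head)
}

fn main() {
    assert!(request_cacheable(&Method::Get));
    assert!(request_cacheable(&Method::Head));
    assert!(!request_cacheable(&Method::Post));
}
```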
/// Decide if the response is cacheable.
///
/// `cache_control` is the parsed [CacheControl] from the response header. It is a standalone
/// argument so that the caller has the flexibility to use, change, or ignore it.
// TODO: vary processing
pub fn resp_cacheable(
cache_control: Option<&CacheControl>,
resp_header: &ResponseHeader,
authorization_present: bool,
defaults: &CacheMetaDefaults,
) -> RespCacheable {
let now = SystemTime::now();
let expire_time = calculate_fresh_until(
now,
cache_control,
resp_header,
authorization_present,
defaults,
);
if let Some(fresh_until) = expire_time {
let (stale_while_revalidate_sec, stale_if_error_sec) =
calculate_serve_stale_sec(cache_control, defaults);
let mut cloned_header = resp_header.clone();
if let Some(cc) = cache_control {
cc.strip_private_headers(&mut cloned_header);
}
return Cacheable(CacheMeta::new(
fresh_until,
now,
stale_while_revalidate_sec,
stale_if_error_sec,
cloned_header,
));
}
Uncacheable(NoCacheReason::OriginNotCache)
}
/// Calculate the [SystemTime] at which the asset expires
///
/// Return `None` when not cacheable.
pub fn calculate_fresh_until(
now: SystemTime,
cache_control: Option<&CacheControl>,
resp_header: &RespHeader,
authorization_present: bool,
defaults: &CacheMetaDefaults,
) -> Option<SystemTime> {
fn freshness_ttl_to_time(now: SystemTime, fresh_sec: u32) -> Option<SystemTime> {
if fresh_sec == 0 {
// ensure that the response is treated as stale
now.checked_sub(Duration::from_secs(1))
} else {
now.checked_add(Duration::from_secs(fresh_sec.into()))
}
}
// A request with Authorization is normally not cacheable, unless Cache-Control allows it
if authorization_present {
let uncacheable = cache_control
.as_ref()
.map_or(true, |cc| !cc.allow_caching_authorized_req());
if uncacheable {
return None;
}
}
let uncacheable = cache_control
.as_ref()
.map_or(false, |cc| cc.is_cacheable() == Cacheable::No);
if uncacheable {
return None;
}
// For TTL check cache-control first, then expires header, then defaults
cache_control
.and_then(|cc| {
cc.fresh_sec()
.and_then(|ttl| freshness_ttl_to_time(now, ttl))
})
.or_else(|| calculate_expires_header_time(resp_header))
.or_else(|| {
defaults
.fresh_sec(resp_header.status)
.and_then(|ttl| freshness_ttl_to_time(now, ttl))
})
}
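The zero-TTL special case in the inner `freshness_ttl_to_time` helper above can be exercised in isolation; a std-only sketch:

```rust
use std::time::{Duration, SystemTime};

// Mirrors the inner helper above: a TTL of 0 maps to a time just in the past
// so the asset is treated as immediately stale; otherwise now + TTL.
fn freshness_ttl_to_time(now: SystemTime, fresh_sec: u32) -> Option<SystemTime> {
    if fresh_sec == 0 {
        now.checked_sub(Duration::from_secs(1))
    } else {
        now.checked_add(Duration::from_secs(fresh_sec.into()))
    }
}

fn main() {
    let now = SystemTime::now();
    // a positive TTL expires in the future
    assert!(freshness_ttl_to_time(now, 10).unwrap() > now);
    // a zero TTL is already expired, forcing revalidation
    assert!(freshness_ttl_to_time(now, 0).unwrap() < now);
}
```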
/// Calculate the expire time from the `Expires` header only
pub fn calculate_expires_header_time(resp_header: &RespHeader) -> Option<SystemTime> {
// according to RFC 7234:
// https://datatracker.ietf.org/doc/html/rfc7234#section-4.2.1
// - treat multiple expires headers as invalid
// https://datatracker.ietf.org/doc/html/rfc7234#section-5.3
// - "MUST interpret invalid date formats... as representing a time in the past"
fn parse_expires_value(expires_value: &HeaderValue) -> Option<SystemTime> {
let expires = expires_value.to_str().ok()?;
Some(SystemTime::from(
expires
.parse::<HttpDate>()
.map_err(|e| warn!("Invalid HttpDate in Expires: {}, error: {}", expires, e))
.ok()?,
))
}
let mut expires_iter = resp_header.headers.get_all("expires").iter();
let expires_header = expires_iter.next();
if expires_header.is_none() || expires_iter.next().is_some() {
return None;
}
parse_expires_value(expires_header.unwrap()).or(Some(SystemTime::UNIX_EPOCH))
}
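The duplicate-header rejection above (take the first value, then require that no second one exists) can be sketched without the http types; `single_value` is a hypothetical helper, not part of the crate:

```rust
// Accept a header value only when exactly one entry is present; multiple
// Expires headers are treated as invalid, per RFC 7234.
fn single_value<'a>(values: &[&'a str]) -> Option<&'a str> {
    let mut iter = values.iter();
    let first = *iter.next()?;
    if iter.next().is_some() {
        return None; // more than one header => invalid
    }
    Some(first)
}

fn main() {
    assert_eq!(
        single_value(&["Fri, 15 May 2015 15:34:21 GMT"]),
        Some("Fri, 15 May 2015 15:34:21 GMT")
    );
    assert_eq!(single_value(&["a", "b"]), None);
    assert_eq!(single_value(&[]), None);
}
```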
/// Calculates stale-while-revalidate and stale-if-error seconds from Cache-Control or the [CacheMetaDefaults].
pub fn calculate_serve_stale_sec(
cache_control: Option<&impl InterpretCacheControl>,
defaults: &CacheMetaDefaults,
) -> (u32, u32) {
let serve_stale_while_revalidate_sec = cache_control
.and_then(|cc| cc.serve_stale_while_revalidate_sec())
.unwrap_or_else(|| defaults.serve_stale_while_revalidate_sec());
let serve_stale_if_error_sec = cache_control
.and_then(|cc| cc.serve_stale_if_error_sec())
.unwrap_or_else(|| defaults.serve_stale_if_error_sec());
(serve_stale_while_revalidate_sec, serve_stale_if_error_sec)
}
/// Filters to run when sending requests to upstream
pub mod upstream {
use super::*;
/// Adjust the request header for cacheable requests
///
/// This filter does the following in order to fetch the entire response to cache
/// - Convert HEAD to GET
/// - `If-*` headers are removed
/// - `Range` header is removed
///
/// When `meta` is set, this function will inject `If-modified-since` according to the `Last-Modified` header
/// and inject `If-none-match` according to the `ETag` header
pub fn request_filter(req: &mut RequestHeader, meta: Option<&CacheMeta>) -> Result<()> {
// change HEAD to GET, HEAD itself is not semantically cacheable
if req.method == Method::HEAD {
req.set_method(Method::GET);
}
// remove downstream precondition headers https://datatracker.ietf.org/doc/html/rfc7232#section-3
// we'd like to cache the 200 not the 304
req.remove_header(&header::IF_MATCH);
req.remove_header(&header::IF_NONE_MATCH);
req.remove_header(&header::IF_MODIFIED_SINCE);
req.remove_header(&header::IF_UNMODIFIED_SINCE);
// see below range header
req.remove_header(&header::IF_RANGE);
// remove downstream range header as we'd like to cache the entire response (this might change in the future)
req.remove_header(&header::RANGE);
// we already have a presumably stale response, add precondition headers for revalidation
if let Some(m) = meta {
// rfc7232: "SHOULD send both validators in cache validation", but
// there have been weird cases where an origin has a matching ETag but not Last-Modified
if let Some(since) = m.headers().get(&header::LAST_MODIFIED) {
req.insert_header(header::IF_MODIFIED_SINCE, since).unwrap();
}
if let Some(etag) = m.headers().get(&header::ETAG) {
req.insert_header(header::IF_NONE_MATCH, etag).unwrap();
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use http::header::{HeaderName, CACHE_CONTROL, EXPIRES, SET_COOKIE};
use http::StatusCode;
use httpdate::fmt_http_date;
fn init_log() {
let _ = env_logger::builder().is_test(true).try_init();
}
const DEFAULTS: CacheMetaDefaults = CacheMetaDefaults::new(
|status| match status {
StatusCode::OK => Some(10),
StatusCode::NOT_FOUND => Some(5),
StatusCode::PARTIAL_CONTENT => None,
_ => Some(1),
},
0,
u32::MAX, /* "infinite" stale-if-error */
);
// Cache nothing, by default
const BYPASS_CACHE_DEFAULTS: CacheMetaDefaults = CacheMetaDefaults::new(|_| None, 0, 0);
fn build_response(status: u16, headers: &[(HeaderName, &str)]) -> ResponseHeader {
let mut header = ResponseHeader::build(status, Some(headers.len())).unwrap();
for (k, v) in headers {
header.append_header(k.to_string(), *v).unwrap();
}
header
}
fn resp_cacheable_wrapper(
resp: &ResponseHeader,
defaults: &CacheMetaDefaults,
authorization_present: bool,
) -> Option<CacheMeta> {
if let Cacheable(meta) = resp_cacheable(
CacheControl::from_resp_headers(resp).as_ref(),
resp,
authorization_present,
defaults,
) {
Some(meta)
} else {
None
}
}
#[test]
fn test_resp_cacheable() {
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "max-age=12345")]),
&DEFAULTS,
false,
);
let meta = meta.unwrap();
assert!(meta.is_fresh(SystemTime::now()));
assert!(meta.is_fresh(
SystemTime::now()
.checked_add(Duration::from_secs(12))
.unwrap()
),);
assert!(!meta.is_fresh(
SystemTime::now()
.checked_add(Duration::from_secs(12346))
.unwrap()
));
}
#[test]
fn test_resp_uncacheable_directives() {
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "private, max-age=12345")]),
&DEFAULTS,
false,
);
assert!(meta.is_none());
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "no-store, max-age=12345")]),
&DEFAULTS,
false,
);
assert!(meta.is_none());
}
#[test]
fn test_resp_cache_authorization() {
let meta = resp_cacheable_wrapper(&build_response(200, &[]), &DEFAULTS, true);
assert!(meta.is_none());
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "max-age=10")]),
&DEFAULTS,
true,
);
assert!(meta.is_none());
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "s-maxage=10")]),
&DEFAULTS,
true,
);
assert!(meta.unwrap().is_fresh(SystemTime::now()));
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "public, max-age=10")]),
&DEFAULTS,
true,
);
assert!(meta.unwrap().is_fresh(SystemTime::now()));
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "must-revalidate")]),
&DEFAULTS,
true,
);
assert!(meta.unwrap().is_fresh(SystemTime::now()));
}
#[test]
fn test_resp_zero_max_age() {
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "max-age=0, public")]),
&DEFAULTS,
false,
);
// cacheable, but needs revalidation
assert!(!meta.unwrap().is_fresh(SystemTime::now()));
}
#[test]
fn test_resp_expires() {
let five_sec_time = SystemTime::now()
.checked_add(Duration::from_secs(5))
.unwrap();
// future expires is cacheable
let meta = resp_cacheable_wrapper(
&build_response(200, &[(EXPIRES, &fmt_http_date(five_sec_time))]),
&DEFAULTS,
false,
);
let meta = meta.unwrap();
assert!(meta.is_fresh(SystemTime::now()));
assert!(!meta.is_fresh(
SystemTime::now()
.checked_add(Duration::from_secs(6))
.unwrap()
));
// even on default uncacheable statuses
let meta = resp_cacheable_wrapper(
&build_response(206, &[(EXPIRES, &fmt_http_date(five_sec_time))]),
&DEFAULTS,
false,
);
assert!(meta.is_some());
}
#[test]
fn test_resp_past_expires() {
// cacheable, but expired
let meta = resp_cacheable_wrapper(
&build_response(200, &[(EXPIRES, "Fri, 15 May 2015 15:34:21 GMT")]),
&BYPASS_CACHE_DEFAULTS,
false,
);
assert!(!meta.unwrap().is_fresh(SystemTime::now()));
}
#[test]
fn test_resp_nonstandard_expires() {
// init log to allow inspecting warnings
init_log();
// invalid cases, according to parser
// (but should be stale according to RFC)
let meta = resp_cacheable_wrapper(
&build_response(200, &[(EXPIRES, "Mon, 13 Feb 0002 12:00:00 GMT")]),
&BYPASS_CACHE_DEFAULTS,
false,
);
assert!(!meta.unwrap().is_fresh(SystemTime::now()));
let meta = resp_cacheable_wrapper(
&build_response(200, &[(EXPIRES, "Fri, 01 Dec 99999 16:00:00 GMT")]),
&BYPASS_CACHE_DEFAULTS,
false,
);
assert!(!meta.unwrap().is_fresh(SystemTime::now()));
let meta = resp_cacheable_wrapper(
&build_response(200, &[(EXPIRES, "0")]),
&BYPASS_CACHE_DEFAULTS,
false,
);
assert!(!meta.unwrap().is_fresh(SystemTime::now()));
}
#[test]
fn test_resp_multiple_expires() {
let five_sec_time = SystemTime::now()
.checked_add(Duration::from_secs(5))
.unwrap();
let ten_sec_time = SystemTime::now()
.checked_add(Duration::from_secs(10))
.unwrap();
// multiple expires = uncacheable
let meta = resp_cacheable_wrapper(
&build_response(
200,
&[
(EXPIRES, &fmt_http_date(five_sec_time)),
(EXPIRES, &fmt_http_date(ten_sec_time)),
],
),
&BYPASS_CACHE_DEFAULTS,
false,
);
assert!(meta.is_none());
// unless the default is cacheable
let meta = resp_cacheable_wrapper(
&build_response(
200,
&[
(EXPIRES, &fmt_http_date(five_sec_time)),
(EXPIRES, &fmt_http_date(ten_sec_time)),
],
),
&DEFAULTS,
false,
);
assert!(meta.is_some());
}
#[test]
fn test_resp_cache_control_with_expires() {
let five_sec_time = SystemTime::now()
.checked_add(Duration::from_secs(5))
.unwrap();
// cache-control takes precedence over expires
let meta = resp_cacheable_wrapper(
&build_response(
200,
&[
(EXPIRES, &fmt_http_date(five_sec_time)),
(CACHE_CONTROL, "max-age=0"),
],
),
&DEFAULTS,
false,
);
assert!(!meta.unwrap().is_fresh(SystemTime::now()));
}
#[test]
fn test_resp_stale_while_revalidate() {
// respect defaults
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "max-age=10")]),
&DEFAULTS,
false,
);
let meta = meta.unwrap();
let eleven_sec_time = SystemTime::now()
.checked_add(Duration::from_secs(11))
.unwrap();
assert!(!meta.is_fresh(eleven_sec_time));
assert!(!meta.serve_stale_while_revalidate(SystemTime::now()));
assert!(!meta.serve_stale_while_revalidate(eleven_sec_time));
// override with stale-while-revalidate
let meta = resp_cacheable_wrapper(
&build_response(
200,
&[(CACHE_CONTROL, "max-age=10, stale-while-revalidate=5")],
),
&DEFAULTS,
false,
);
let meta = meta.unwrap();
let eleven_sec_time = SystemTime::now()
.checked_add(Duration::from_secs(11))
.unwrap();
let sixteen_sec_time = SystemTime::now()
.checked_add(Duration::from_secs(16))
.unwrap();
assert!(!meta.is_fresh(eleven_sec_time));
assert!(meta.serve_stale_while_revalidate(eleven_sec_time));
assert!(!meta.serve_stale_while_revalidate(sixteen_sec_time));
}
#[test]
fn test_resp_stale_if_error() {
// respect defaults
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "max-age=10")]),
&DEFAULTS,
false,
);
let meta = meta.unwrap();
let hundred_years_time = SystemTime::now()
.checked_add(Duration::from_secs(86400 * 365 * 100))
.unwrap();
assert!(!meta.is_fresh(hundred_years_time));
assert!(meta.serve_stale_if_error(hundred_years_time));
// override with stale-if-error
let meta = resp_cacheable_wrapper(
&build_response(
200,
&[(
CACHE_CONTROL,
"max-age=10, stale-while-revalidate=5, stale-if-error=60",
)],
),
&DEFAULTS,
false,
);
let meta = meta.unwrap();
let eleven_sec_time = SystemTime::now()
.checked_add(Duration::from_secs(11))
.unwrap();
let seventy_sec_time = SystemTime::now()
.checked_add(Duration::from_secs(70))
.unwrap();
assert!(!meta.is_fresh(eleven_sec_time));
assert!(meta.serve_stale_if_error(SystemTime::now()));
assert!(meta.serve_stale_if_error(eleven_sec_time));
assert!(!meta.serve_stale_if_error(seventy_sec_time));
// never serve stale
let meta = resp_cacheable_wrapper(
&build_response(200, &[(CACHE_CONTROL, "max-age=10, stale-if-error=0")]),
&DEFAULTS,
false,
);
let meta = meta.unwrap();
let eleven_sec_time = SystemTime::now()
.checked_add(Duration::from_secs(11))
.unwrap();
assert!(!meta.is_fresh(eleven_sec_time));
assert!(!meta.serve_stale_if_error(eleven_sec_time));
}
#[test]
fn test_resp_status_cache_defaults() {
// 200 response
let meta = resp_cacheable_wrapper(&build_response(200, &[]), &DEFAULTS, false);
assert!(meta.is_some());
let meta = meta.unwrap();
assert!(meta.is_fresh(
SystemTime::now()
.checked_add(Duration::from_secs(9))
.unwrap()
));
assert!(!meta.is_fresh(
SystemTime::now()
.checked_add(Duration::from_secs(11))
.unwrap()
));
// 404 response, different ttl
let meta = resp_cacheable_wrapper(&build_response(404, &[]), &DEFAULTS, false);
assert!(meta.is_some());
let meta = meta.unwrap();
assert!(meta.is_fresh(
SystemTime::now()
.checked_add(Duration::from_secs(4))
.unwrap()
));
assert!(!meta.is_fresh(
SystemTime::now()
.checked_add(Duration::from_secs(6))
.unwrap()
));
// 206 marked uncacheable (no cache TTL)
let meta = resp_cacheable_wrapper(&build_response(206, &[]), &DEFAULTS, false);
assert!(meta.is_none());
// default uncacheable status with explicit Cache-Control is cacheable
let meta = resp_cacheable_wrapper(
&build_response(206, &[(CACHE_CONTROL, "public, max-age=10")]),
&DEFAULTS,
false,
);
assert!(meta.is_some());
let meta = meta.unwrap();
assert!(meta.is_fresh(
SystemTime::now()
.checked_add(Duration::from_secs(9))
.unwrap()
));
assert!(!meta.is_fresh(
SystemTime::now()
.checked_add(Duration::from_secs(11))
.unwrap()
));
// 416 hits the catch-all default TTL (any other status)
let meta = resp_cacheable_wrapper(&build_response(416, &[]), &DEFAULTS, false);
assert!(meta.is_some());
let meta = meta.unwrap();
assert!(meta.is_fresh(SystemTime::now()));
assert!(!meta.is_fresh(
SystemTime::now()
.checked_add(Duration::from_secs(2))
.unwrap()
));
}
#[test]
fn test_resp_cache_no_cache_fields() {
// check #field-names are stripped from the cache header
let meta = resp_cacheable_wrapper(
&build_response(
200,
&[
(SET_COOKIE, "my-cookie"),
(CACHE_CONTROL, "private=\"something\", max-age=10"),
(HeaderName::from_bytes(b"Something").unwrap(), "foo"),
],
),
&DEFAULTS,
false,
);
let meta = meta.unwrap();
assert!(meta.headers().contains_key(SET_COOKIE));
assert!(!meta.headers().contains_key("Something"));
let meta = resp_cacheable_wrapper(
&build_response(
200,
&[
(SET_COOKIE, "my-cookie"),
(
CACHE_CONTROL,
"max-age=0, no-cache=\"meta1, SeT-Cookie ,meta2\"",
),
(HeaderName::from_bytes(b"meta1").unwrap(), "foo"),
],
),
&DEFAULTS,
false,
);
let meta = meta.unwrap();
assert!(!meta.headers().contains_key(SET_COOKIE));
assert!(!meta.headers().contains_key("meta1"));
}
}

// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Concurrent hash tables and LRUs
use lru::LruCache;
use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::collections::HashMap;
// There are probably off-the-shelf crates for this, e.g. DashMap
/// A hash table that shards to a constant number of tables to reduce lock contention
pub struct ConcurrentHashTable<V, const N: usize> {
tables: [RwLock<HashMap<u128, V>>; N],
}
#[inline]
fn get_shard(key: u128, n_shards: usize) -> usize {
(key % n_shards as u128) as usize
}
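The shard selector above is plain modulo arithmetic, so equal keys always land on the same lock; a self-contained sketch:

```rust
// Mirrors get_shard above: map a 128-bit key to one of n_shards tables.
// Keys here are already hashes, so modulo distributes them roughly evenly.
fn get_shard(key: u128, n_shards: usize) -> usize {
    (key % n_shards as u128) as usize
}

fn main() {
    const N: usize = 4;
    assert_eq!(get_shard(5, N), 1);
    assert_eq!(get_shard(u128::MAX, N), 3); // 2^128 - 1 leaves remainder 3 mod 4
    // equal keys deterministically map to the same shard
    assert_eq!(get_shard(42, N), get_shard(42, N));
}
```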
impl<V, const N: usize> ConcurrentHashTable<V, N>
where
[RwLock<HashMap<u128, V>>; N]: Default,
{
pub fn new() -> Self {
ConcurrentHashTable {
tables: Default::default(),
}
}
pub fn get(&self, key: u128) -> &RwLock<HashMap<u128, V>> {
&self.tables[get_shard(key, N)]
}
#[allow(dead_code)]
pub fn read(&self, key: u128) -> RwLockReadGuard<HashMap<u128, V>> {
self.get(key).read()
}
pub fn write(&self, key: u128) -> RwLockWriteGuard<HashMap<u128, V>> {
self.get(key).write()
}
// TODO: work out the lifetimes to provide get/set directly
}
impl<V, const N: usize> Default for ConcurrentHashTable<V, N>
where
[RwLock<HashMap<u128, V>>; N]: Default,
{
fn default() -> Self {
Self::new()
}
}
#[doc(hidden)] // not needed in the public API
pub struct LruShard<V>(RwLock<LruCache<u128, V>>);
impl<V> Default for LruShard<V> {
fn default() -> Self {
// help satisfy default construction of array
LruShard(RwLock::new(LruCache::unbounded()))
}
}
/// Sharded concurrent data structure for LruCache
pub struct ConcurrentLruCache<V, const N: usize> {
lrus: [LruShard<V>; N],
}
impl<V, const N: usize> ConcurrentLruCache<V, N>
where
[LruShard<V>; N]: Default,
{
pub fn new(shard_capacity: usize) -> Self {
use std::num::NonZeroUsize;
// safe, 1 != 0
const ONE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(1) };
let mut cache = ConcurrentLruCache {
lrus: Default::default(),
};
for lru in &mut cache.lrus {
lru.0
.write()
.resize(shard_capacity.try_into().unwrap_or(ONE));
}
cache
}
pub fn get(&self, key: u128) -> &RwLock<LruCache<u128, V>> {
&self.lrus[get_shard(key, N)].0
}
#[allow(dead_code)]
pub fn read(&self, key: u128) -> RwLockReadGuard<LruCache<u128, V>> {
self.get(key).read()
}
pub fn write(&self, key: u128) -> RwLockWriteGuard<LruCache<u128, V>> {
self.get(key).write()
}
// TODO: work out the lifetimes to provide get/set directly
}

pingora-cache/src/key.rs
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Cache key
use super::*;
use blake2::{Blake2b, Digest};
use serde::{Deserialize, Serialize};
// 16-byte / 128-bit key: large enough to avoid collision
const KEY_SIZE: usize = 16;
/// A 128-bit hash binary
pub type HashBinary = [u8; KEY_SIZE];
fn hex2str(hex: &[u8]) -> String {
use std::fmt::Write;
let mut s = String::with_capacity(KEY_SIZE * 2);
for c in hex {
write!(s, "{:02x}", c).unwrap(); // safe, just dump hex to string
}
s
}
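`hex2str` above is a plain byte-to-lowercase-hex dump; the same logic runs standalone:

```rust
use std::fmt::Write;

// Mirrors hex2str above: each byte becomes two lowercase hex digits.
fn hex2str(hex: &[u8]) -> String {
    let mut s = String::with_capacity(hex.len() * 2);
    for c in hex {
        write!(s, "{:02x}", c).unwrap(); // writing to a String cannot fail
    }
    s
}

fn main() {
    assert_eq!(hex2str(&[0x00, 0xff, 0x10]), "00ff10");
    assert_eq!(hex2str(&[]), "");
}
```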
/// Decode the hex str into [HashBinary].
///
/// Return `None` when decoding fails or the input is not exactly 32 hex characters (which decode to 16 bytes).
pub fn str2hex(s: &str) -> Option<HashBinary> {
if s.len() != KEY_SIZE * 2 {
return None;
}
let mut output = [0; KEY_SIZE];
// no need to bubble the error, it should be obvious why the decode fails
hex::decode_to_slice(s.as_bytes(), &mut output).ok()?;
Some(output)
}
/// The trait for cache key
pub trait CacheHashKey {
/// Return the hash of the cache key
fn primary_bin(&self) -> HashBinary;
/// Return the variance hash of the cache key.
///
/// `None` if no variance.
fn variance_bin(&self) -> Option<HashBinary>;
/// Return the hash including both primary and variance keys
fn combined_bin(&self) -> HashBinary {
let key = self.primary_bin();
if let Some(v) = self.variance_bin() {
let mut hasher = Blake2b128::new();
hasher.update(key);
hasher.update(v);
hasher.finalize().into()
} else {
// if there is no variance, combined_bin should return the same as primary_bin
key
}
}
/// An extra tag for identifying users
///
/// For example, if the storage backend implements a per-user quota, this tag can be used.
fn user_tag(&self) -> &str;
/// The hex string of [Self::primary_bin()]
fn primary(&self) -> String {
hex2str(&self.primary_bin())
}
/// The hex string of [Self::variance_bin()]
fn variance(&self) -> Option<String> {
self.variance_bin().as_ref().map(|b| hex2str(&b[..]))
}
/// The hex string of [Self::combined_bin()]
fn combined(&self) -> String {
hex2str(&self.combined_bin())
}
}
/// General purpose cache key
#[derive(Debug, Clone)]
pub struct CacheKey {
// All strings for now; could be more structured as long as it can be hashed
namespace: String,
primary: String,
variance: Option<HashBinary>,
/// An extra tag for identifying users
///
/// For example, if the storage backend implements a per-user quota, this tag can be used.
pub user_tag: String,
}
impl CacheKey {
/// Set the value of the variance hash
pub fn set_variance_key(&mut self, key: HashBinary) {
self.variance = Some(key)
}
/// Get the value of the variance hash
pub fn get_variance_key(&self) -> Option<&HashBinary> {
self.variance.as_ref()
}
/// Removes the variance from this cache key
pub fn remove_variance_key(&mut self) {
self.variance = None
}
}
/// Storage optimized cache key to keep in memory or in storage
// 16 bytes + 8 bytes (+16 * u8) + user_tag.len() + 16 Bytes (Box<str>)
#[derive(Debug, Deserialize, Serialize, Clone, Hash, PartialEq, Eq)]
pub struct CompactCacheKey {
pub primary: HashBinary,
// Box saves 8 bytes in the common non-variance case, at the cost of 8 extra bytes when variance is present, vs. storing a flat 16 bytes
pub variance: Option<Box<HashBinary>>,
pub user_tag: Box<str>, // the len should be small to keep memory usage bounded
}
impl CacheHashKey for CompactCacheKey {
fn primary_bin(&self) -> HashBinary {
self.primary
}
fn variance_bin(&self) -> Option<HashBinary> {
self.variance.as_ref().map(|s| *s.as_ref())
}
fn user_tag(&self) -> &str {
&self.user_tag
}
}
/*
 * We use blake2 hashing, which is faster and more secure, to replace md5.
 * We have not given much thought to whether a non-crypto hash could be used
 * safely, because hashing performance is not critical.
 * Note: we should avoid hashes like ahash, which do not produce consistent
 * output across machines because they are designed purely for in-memory
 * hashtables.
 */
// hash output: we use 128 bits (16 bytes) hash which will map to 32 bytes hex string
pub(crate) type Blake2b128 = Blake2b<blake2::digest::consts::U16>;
/// helper function: hash str to u8
pub fn hash_u8(key: &str) -> u8 {
let mut hasher = Blake2b128::new();
hasher.update(key);
let raw = hasher.finalize();
raw[0]
}
/// helper function: hash str to [HashBinary]
pub fn hash_key(key: &str) -> HashBinary {
let mut hasher = Blake2b128::new();
hasher.update(key);
let raw = hasher.finalize();
raw.into()
}
impl CacheKey {
fn primary_hasher(&self) -> Blake2b128 {
let mut hasher = Blake2b128::new();
hasher.update(&self.namespace);
hasher.update(&self.primary);
hasher
}
    /// Create a default [CacheKey] from a request, which just takes its URI as the primary key.
pub fn default(req_header: &ReqHeader) -> Self {
CacheKey {
namespace: "".into(),
primary: format!("{}", req_header.uri),
variance: None,
user_tag: "".into(),
}
}
/// Create a new [CacheKey] from the given namespace, primary, and user_tag string.
///
/// Both `namespace` and `primary` will be used for the primary hash
pub fn new<S1, S2, S3>(namespace: S1, primary: S2, user_tag: S3) -> Self
where
S1: Into<String>,
S2: Into<String>,
S3: Into<String>,
{
CacheKey {
namespace: namespace.into(),
primary: primary.into(),
variance: None,
user_tag: user_tag.into(),
}
}
/// Return the namespace of this key
pub fn namespace(&self) -> &str {
&self.namespace
}
/// Return the primary key of this key
pub fn primary_key(&self) -> &str {
&self.primary
}
/// Convert this key to [CompactCacheKey].
pub fn to_compact(&self) -> CompactCacheKey {
let primary = self.primary_bin();
CompactCacheKey {
primary,
variance: self.variance_bin().map(Box::new),
user_tag: self.user_tag.clone().into_boxed_str(),
}
}
}
impl CacheHashKey for CacheKey {
fn primary_bin(&self) -> HashBinary {
self.primary_hasher().finalize().into()
}
fn variance_bin(&self) -> Option<HashBinary> {
self.variance
}
fn user_tag(&self) -> &str {
&self.user_tag
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cache_key_hash() {
let key = CacheKey {
namespace: "".into(),
primary: "aa".into(),
variance: None,
user_tag: "1".into(),
};
let hash = key.primary();
assert_eq!(hash, "ac10f2aef117729f8dad056b3059eb7e");
assert!(key.variance().is_none());
assert_eq!(key.combined(), hash);
let compact = key.to_compact();
assert_eq!(compact.primary(), hash);
assert!(compact.variance().is_none());
assert_eq!(compact.combined(), hash);
}
#[test]
fn test_cache_key_vary_hash() {
let key = CacheKey {
namespace: "".into(),
primary: "aa".into(),
variance: Some([0u8; 16]),
user_tag: "1".into(),
};
let hash = key.primary();
assert_eq!(hash, "ac10f2aef117729f8dad056b3059eb7e");
assert_eq!(key.variance().unwrap(), "00000000000000000000000000000000");
assert_eq!(key.combined(), "004174d3e75a811a5b44c46b3856f3ee");
let compact = key.to_compact();
assert_eq!(compact.primary(), "ac10f2aef117729f8dad056b3059eb7e");
assert_eq!(
compact.variance().unwrap(),
"00000000000000000000000000000000"
);
assert_eq!(compact.combined(), "004174d3e75a811a5b44c46b3856f3ee");
}
#[test]
fn test_hex_str() {
let mut key = [0; KEY_SIZE];
for (i, v) in key.iter_mut().enumerate() {
// key: [0, 1, 2, .., 15]
*v = i as u8;
}
let hex_str = hex2str(&key);
let key2 = str2hex(&hex_str).unwrap();
for i in 0..KEY_SIZE {
assert_eq!(key[i], key2[i]);
}
}
}
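The `test_hex_str` round trip above exercises `hex2str`/`str2hex`, which are defined elsewhere in this module. A minimal std-only sketch of what such helpers could look like (the names and the `KEY_SIZE` constant mirror the ones used here; the real implementations may differ):

```rust
const KEY_SIZE: usize = 16;

// Lowercase-hex encode a 16-byte key into a 32-char string.
fn hex2str(key: &[u8; KEY_SIZE]) -> String {
    key.iter().map(|b| format!("{b:02x}")).collect()
}

// Decode a 32-char hex string back into 16 bytes; None on bad input.
fn str2hex(s: &str) -> Option<[u8; KEY_SIZE]> {
    if s.len() != KEY_SIZE * 2 {
        return None;
    }
    let mut out = [0u8; KEY_SIZE];
    for (i, chunk) in s.as_bytes().chunks(2).enumerate() {
        let hi = (chunk[0] as char).to_digit(16)?;
        let lo = (chunk[1] as char).to_digit(16)?;
        out[i] = (hi * 16 + lo) as u8;
    }
    Some(out)
}

fn main() {
    let mut key = [0u8; KEY_SIZE];
    for (i, v) in key.iter_mut().enumerate() {
        *v = i as u8; // key: [0, 1, 2, .., 15]
    }
    let s = hex2str(&key);
    assert_eq!(s.len(), KEY_SIZE * 2);
    assert_eq!(str2hex(&s), Some(key)); // round trip recovers the key
}
```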

1093
pingora-cache/src/lib.rs Normal file

File diff suppressed because it is too large

336
pingora-cache/src/lock.rs Normal file

@ -0,0 +1,336 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Cache lock
use crate::key::CacheHashKey;
use crate::hashtable::ConcurrentHashTable;
use pingora_timeout::timeout;
use std::sync::Arc;
const N_SHARDS: usize = 16;
/// The global cache locking manager
pub struct CacheLock {
lock_table: ConcurrentHashTable<LockStub, N_SHARDS>,
timeout: Duration, // fixed timeout value for now
}
/// An enum representing a locked cache access
#[derive(Debug)]
pub enum Locked {
/// The writer is allowed to fetch the asset
Write(WritePermit),
/// The reader waits for the writer to fetch the asset
Read(ReadLock),
}
impl Locked {
/// Is this a write lock
pub fn is_write(&self) -> bool {
matches!(self, Self::Write(_))
}
}
impl CacheLock {
/// Create a new [CacheLock] with the given lock timeout
///
/// When the timeout is reached, the read locks are automatically unlocked
pub fn new(timeout: Duration) -> Self {
CacheLock {
lock_table: ConcurrentHashTable::new(),
timeout,
}
}
/// Try to lock a cache fetch
///
    /// Users should call this after a cache miss, before fetching the asset.
/// The returned [Locked] will tell the caller either to fetch or wait.
pub fn lock<K: CacheHashKey>(&self, key: &K) -> Locked {
let hash = key.combined_bin();
let key = u128::from_be_bytes(hash); // endianness doesn't matter
let table = self.lock_table.get(key);
if let Some(lock) = table.read().get(&key) {
// already has an ongoing request
if lock.0.lock_status() != LockStatus::Dangling {
return Locked::Read(lock.read_lock());
}
// Dangling: the previous writer quit without unlocking the lock. Requests should
// compete for the write lock again.
}
let (permit, stub) = WritePermit::new(self.timeout);
let mut table = table.write();
// check again in case another request already added it
if let Some(lock) = table.get(&key) {
if lock.0.lock_status() != LockStatus::Dangling {
return Locked::Read(lock.read_lock());
}
}
table.insert(key, stub);
Locked::Write(permit)
}
/// Release a lock for the given key
///
    /// When the write lock is dropped without being released, the read lock holders will consider
    /// it failed and will compete for the write lock again.
pub fn release<K: CacheHashKey>(&self, key: &K, reason: LockStatus) {
let hash = key.combined_bin();
let key = u128::from_be_bytes(hash); // endianness doesn't matter
if let Some(lock) = self.lock_table.write(key).remove(&key) {
// make sure that the caller didn't forget to unlock it
if lock.0.locked() {
lock.0.unlock(reason);
}
}
}
}
use std::sync::atomic::{AtomicU8, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
/// Status which the read locks could possibly see.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum LockStatus {
/// Waiting for the writer to populate the asset
Waiting,
/// The writer finishes, readers can start
Done,
/// The writer encountered error, such as network issue. A new writer will be elected.
TransientError,
/// The writer observed that no cache lock is needed (e.g., uncacheable), readers should start
/// to fetch independently without a new writer
GiveUp,
/// The write lock is dropped without being unlocked
Dangling,
/// The lock is held for too long
Timeout,
}
impl From<LockStatus> for u8 {
fn from(l: LockStatus) -> u8 {
match l {
LockStatus::Waiting => 0,
LockStatus::Done => 1,
LockStatus::TransientError => 2,
LockStatus::GiveUp => 3,
LockStatus::Dangling => 4,
LockStatus::Timeout => 5,
}
}
}
impl From<u8> for LockStatus {
fn from(v: u8) -> Self {
match v {
0 => Self::Waiting,
1 => Self::Done,
2 => Self::TransientError,
3 => Self::GiveUp,
4 => Self::Dangling,
5 => Self::Timeout,
_ => Self::GiveUp, // placeholder
}
}
}
#[derive(Debug)]
struct LockCore {
pub lock_start: Instant,
pub timeout: Duration,
pub(super) lock: Semaphore,
// use u8 for Atomic enum
lock_status: AtomicU8,
}
impl LockCore {
pub fn new_arc(timeout: Duration) -> Arc<Self> {
Arc::new(LockCore {
lock: Semaphore::new(0),
timeout,
lock_start: Instant::now(),
lock_status: AtomicU8::new(LockStatus::Waiting.into()),
})
}
fn locked(&self) -> bool {
self.lock.available_permits() == 0
}
fn unlock(&self, reason: LockStatus) {
self.lock_status.store(reason.into(), Ordering::SeqCst);
// any small positive number will do, 10 is used for RwLock too
// no need to wake up all at once
self.lock.add_permits(10);
}
fn lock_status(&self) -> LockStatus {
self.lock_status.load(Ordering::Relaxed).into()
}
}
// all 3 structs below are just Arc<LockCore> with different interfaces
/// ReadLock: requests that get it need to wait until it is released
#[derive(Debug)]
pub struct ReadLock(Arc<LockCore>);
impl ReadLock {
/// Wait for the writer to release the lock
pub async fn wait(&self) {
if !self.locked() || self.expired() {
return;
}
        // TODO: should subtract now - start so that the lock doesn't wait beyond start + timeout
// Also need to be careful not to wake everyone up at the same time
// (maybe not an issue because regular cache lock release behaves that way)
let _ = timeout(self.0.timeout, self.0.lock.acquire()).await;
// permit is returned to Semaphore right away
}
/// Test if it is still locked
pub fn locked(&self) -> bool {
self.0.locked()
}
/// Whether the lock is expired, e.g., the writer has been holding the lock for too long
pub fn expired(&self) -> bool {
        // NOTE: this checks whether the lock is currently expired,
        // not whether it timed out during wait()
self.0.lock_start.elapsed() >= self.0.timeout
}
/// The current status of the lock
pub fn lock_status(&self) -> LockStatus {
let status = self.0.lock_status();
if matches!(status, LockStatus::Waiting) && self.expired() {
LockStatus::Timeout
} else {
status
}
}
}
/// WritePermit: requests that get it need to populate the cache and then release it
#[derive(Debug)]
pub struct WritePermit(Arc<LockCore>);
impl WritePermit {
fn new(timeout: Duration) -> (WritePermit, LockStub) {
let lock = LockCore::new_arc(timeout);
let stub = LockStub(lock.clone());
(WritePermit(lock), stub)
}
fn unlock(&self, reason: LockStatus) {
self.0.unlock(reason)
}
}
impl Drop for WritePermit {
fn drop(&mut self) {
        // the writer exited without properly unlocking; let others compete for the write lock again
if self.0.locked() {
self.unlock(LockStatus::Dangling);
}
}
}
struct LockStub(Arc<LockCore>);
impl LockStub {
pub fn read_lock(&self) -> ReadLock {
ReadLock(self.0.clone())
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::CacheKey;
#[test]
fn test_get_release() {
let cache_lock = CacheLock::new(Duration::from_secs(1000));
let key1 = CacheKey::new("", "a", "1");
let locked1 = cache_lock.lock(&key1);
assert!(locked1.is_write()); // write permit
let locked2 = cache_lock.lock(&key1);
assert!(!locked2.is_write()); // read lock
cache_lock.release(&key1, LockStatus::Done);
let locked3 = cache_lock.lock(&key1);
assert!(locked3.is_write()); // write permit again
}
#[tokio::test]
async fn test_lock() {
let cache_lock = CacheLock::new(Duration::from_secs(1000));
let key1 = CacheKey::new("", "a", "1");
let permit = match cache_lock.lock(&key1) {
Locked::Write(w) => w,
_ => panic!(),
};
let lock = match cache_lock.lock(&key1) {
Locked::Read(r) => r,
_ => panic!(),
};
assert!(lock.locked());
let handle = tokio::spawn(async move {
lock.wait().await;
assert_eq!(lock.lock_status(), LockStatus::Done);
});
permit.unlock(LockStatus::Done);
handle.await.unwrap(); // check lock is unlocked and the task is returned
}
#[tokio::test]
async fn test_lock_timeout() {
let cache_lock = CacheLock::new(Duration::from_secs(1));
let key1 = CacheKey::new("", "a", "1");
let permit = match cache_lock.lock(&key1) {
Locked::Write(w) => w,
_ => panic!(),
};
let lock = match cache_lock.lock(&key1) {
Locked::Read(r) => r,
_ => panic!(),
};
assert!(lock.locked());
let handle = tokio::spawn(async move {
// timed out
lock.wait().await;
assert_eq!(lock.lock_status(), LockStatus::Timeout);
});
tokio::time::sleep(Duration::from_secs(2)).await;
// expired lock
let lock2 = match cache_lock.lock(&key1) {
Locked::Read(r) => r,
_ => panic!(),
};
assert!(lock2.locked());
assert_eq!(lock2.lock_status(), LockStatus::Timeout);
lock2.wait().await;
assert_eq!(lock2.lock_status(), LockStatus::Timeout);
permit.unlock(LockStatus::Done);
handle.await.unwrap();
}
}
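`LockCore` keeps its `LockStatus` in an `AtomicU8` by going through the two `From` impls above. A std-only round trip of that encoding (the enum and conversions are copied from this file; only the atomic usage pattern around them is illustrative):

```rust
use std::sync::atomic::{AtomicU8, Ordering};

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum LockStatus {
    Waiting,
    Done,
    TransientError,
    GiveUp,
    Dangling,
    Timeout,
}

impl From<LockStatus> for u8 {
    fn from(l: LockStatus) -> u8 {
        match l {
            LockStatus::Waiting => 0,
            LockStatus::Done => 1,
            LockStatus::TransientError => 2,
            LockStatus::GiveUp => 3,
            LockStatus::Dangling => 4,
            LockStatus::Timeout => 5,
        }
    }
}

impl From<u8> for LockStatus {
    fn from(v: u8) -> Self {
        match v {
            0 => Self::Waiting,
            1 => Self::Done,
            2 => Self::TransientError,
            3 => Self::GiveUp,
            4 => Self::Dangling,
            5 => Self::Timeout,
            _ => Self::GiveUp, // placeholder for unknown values
        }
    }
}

fn main() {
    // the same pattern LockCore uses: keep the status in an AtomicU8
    let status = AtomicU8::new(LockStatus::Waiting.into());
    status.store(LockStatus::Done.into(), Ordering::SeqCst);
    assert_eq!(LockStatus::from(status.load(Ordering::Relaxed)), LockStatus::Done);
    // out-of-range bytes decode to the GiveUp placeholder
    assert_eq!(LockStatus::from(42u8), LockStatus::GiveUp);
}
```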


@ -0,0 +1,75 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Set limit on the largest size to cache
use crate::storage::HandleMiss;
use crate::MissHandler;
use async_trait::async_trait;
use bytes::Bytes;
use pingora_error::{Error, ErrorType};
/// [MaxFileSizeMissHandler] wraps a MissHandler to enforce a maximum asset size that should be
/// written to the MissHandler.
///
/// This is used to enforce a maximum cache size for a request when the
/// response size is not known ahead of time (no Content-Length header). When the response size _is_
/// known ahead of time, it should be checked up front (when calculating cacheability) for efficiency.
/// Note: for requests with partial read support (where downstream reads the response from cache as
/// it is filled), this will cause the request as a whole to fail. The response will be remembered
/// as uncacheable, though, so downstream will be able to retry the request, since the cache will be
/// disabled for the retried request.
pub struct MaxFileSizeMissHandler {
inner: MissHandler,
max_file_size_bytes: usize,
bytes_written: usize,
}
impl MaxFileSizeMissHandler {
/// Create a new [MaxFileSizeMissHandler] wrapping the given [MissHandler]
pub fn new(inner: MissHandler, max_file_size_bytes: usize) -> MaxFileSizeMissHandler {
MaxFileSizeMissHandler {
inner,
max_file_size_bytes,
bytes_written: 0,
}
}
}
/// Error type returned when the limit is reached.
pub const ERR_RESPONSE_TOO_LARGE: ErrorType = ErrorType::Custom("response too large");
#[async_trait]
impl HandleMiss for MaxFileSizeMissHandler {
async fn write_body(&mut self, data: Bytes, eof: bool) -> pingora_error::Result<()> {
// fail if writing the body would exceed the max_file_size_bytes
if self.bytes_written + data.len() > self.max_file_size_bytes {
return Error::e_explain(
ERR_RESPONSE_TOO_LARGE,
format!(
"writing data of size {} bytes would exceed max file size of {} bytes",
data.len(),
self.max_file_size_bytes
),
);
}
self.bytes_written += data.len();
self.inner.write_body(data, eof).await
}
async fn finish(self: Box<Self>) -> pingora_error::Result<usize> {
self.inner.finish().await
}
}
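The check in `write_body` above is a running-total guard: reject the write before any byte reaches the inner handler, so the cap is never exceeded. Stripped of the async trait and Pingora error types, the core accounting can be sketched with std only (`SizeCappedWriter` and its `Vec` sink are stand-ins for the wrapped `MissHandler`):

```rust
struct SizeCappedWriter {
    buf: Vec<u8>, // stand-in for the inner miss handler
    max_bytes: usize,
    written: usize,
}

impl SizeCappedWriter {
    fn write_body(&mut self, data: &[u8]) -> Result<(), String> {
        // reject before writing, so the inner handler never exceeds the cap
        if self.written + data.len() > self.max_bytes {
            return Err(format!(
                "writing data of size {} bytes would exceed max file size of {} bytes",
                data.len(),
                self.max_bytes
            ));
        }
        self.written += data.len();
        self.buf.extend_from_slice(data);
        Ok(())
    }
}

fn main() {
    let mut w = SizeCappedWriter { buf: Vec::new(), max_bytes: 8, written: 0 };
    assert!(w.write_body(b"12345").is_ok());
    assert!(w.write_body(b"678").is_ok()); // landing exactly on the cap is allowed
    assert!(w.write_body(b"9").is_err()); // one byte over fails
    assert_eq!(w.written, 8);
}
```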

510
pingora-cache/src/memory.rs Normal file

@ -0,0 +1,510 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Hash map based in memory cache
//!
//! For testing only, not for production use
//TODO: Mark this module #[test] only
use super::*;
use crate::key::{CacheHashKey, CompactCacheKey};
use crate::storage::{HandleHit, HandleMiss, Storage};
use crate::trace::SpanHandle;
use async_trait::async_trait;
use bytes::Bytes;
use parking_lot::RwLock;
use pingora_error::*;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::watch;
type BinaryMeta = (Vec<u8>, Vec<u8>);
pub(crate) struct CacheObject {
pub meta: BinaryMeta,
pub body: Arc<Vec<u8>>,
}
pub(crate) struct TempObject {
pub meta: BinaryMeta,
    // these are Arc because they need to continue to exist after this TempObject is removed
pub body: Arc<RwLock<Vec<u8>>>,
bytes_written: Arc<watch::Sender<PartialState>>, // this should match body.len()
}
impl TempObject {
fn new(meta: BinaryMeta) -> Self {
let (tx, _rx) = watch::channel(PartialState::Partial(0));
TempObject {
meta,
body: Arc::new(RwLock::new(Vec::new())),
bytes_written: Arc::new(tx),
}
}
// this is not at all optimized
fn make_cache_object(&self) -> CacheObject {
let meta = self.meta.clone();
let body = Arc::new(self.body.read().clone());
CacheObject { meta, body }
}
}
/// Hash map based in memory cache
///
/// For testing only, not for production use.
pub struct MemCache {
pub(crate) cached: Arc<RwLock<HashMap<String, CacheObject>>>,
pub(crate) temp: Arc<RwLock<HashMap<String, TempObject>>>,
}
impl MemCache {
/// Create a new [MemCache]
pub fn new() -> Self {
MemCache {
cached: Arc::new(RwLock::new(HashMap::new())),
temp: Arc::new(RwLock::new(HashMap::new())),
}
}
}
pub enum MemHitHandler {
Complete(CompleteHit),
Partial(PartialHit),
}
#[derive(Copy, Clone)]
enum PartialState {
Partial(usize),
Complete(usize),
}
pub struct CompleteHit {
body: Arc<Vec<u8>>,
done: bool,
range_start: usize,
range_end: usize,
}
impl CompleteHit {
fn get(&mut self) -> Option<Bytes> {
if self.done {
None
} else {
self.done = true;
Some(Bytes::copy_from_slice(
&self.body.as_slice()[self.range_start..self.range_end],
))
}
}
fn seek(&mut self, start: usize, end: Option<usize>) -> Result<()> {
if start >= self.body.len() {
return Error::e_explain(
ErrorType::InternalError,
format!("seek start out of range {start} >= {}", self.body.len()),
);
}
self.range_start = start;
if let Some(end) = end {
// end over the actual last byte is allowed, we just need to return the actual bytes
self.range_end = std::cmp::min(self.body.len(), end);
}
// seek resets read so that one handler can be used for multiple ranges
self.done = false;
Ok(())
}
}
pub struct PartialHit {
body: Arc<RwLock<Vec<u8>>>,
bytes_written: watch::Receiver<PartialState>,
bytes_read: usize,
}
impl PartialHit {
async fn read(&mut self) -> Option<Bytes> {
loop {
let bytes_written = *self.bytes_written.borrow_and_update();
let bytes_end = match bytes_written {
PartialState::Partial(s) => s,
PartialState::Complete(c) => {
// no more data will arrive
if c == self.bytes_read {
return None;
}
c
}
};
assert!(bytes_end >= self.bytes_read);
            // more data available to read
if bytes_end > self.bytes_read {
let new_bytes =
Bytes::copy_from_slice(&self.body.read()[self.bytes_read..bytes_end]);
self.bytes_read = bytes_end;
return Some(new_bytes);
}
// wait for more data
if self.bytes_written.changed().await.is_err() {
// err: sender dropped, body is finished
// FIXME: sender could drop because of an error
return None;
}
}
}
}
#[async_trait]
impl HandleHit for MemHitHandler {
async fn read_body(&mut self) -> Result<Option<Bytes>> {
match self {
Self::Complete(c) => Ok(c.get()),
Self::Partial(p) => Ok(p.read().await),
}
}
async fn finish(
self: Box<Self>, // because self is always used as a trait object
_storage: &'static (dyn storage::Storage + Sync),
_key: &CacheKey,
_trace: &SpanHandle,
) -> Result<()> {
Ok(())
}
fn can_seek(&self) -> bool {
match self {
Self::Complete(_) => true,
Self::Partial(_) => false, // TODO: support seeking in partial reads
}
}
fn seek(&mut self, start: usize, end: Option<usize>) -> Result<()> {
match self {
Self::Complete(c) => c.seek(start, end),
Self::Partial(_) => Error::e_explain(
ErrorType::InternalError,
"seek not supported for partial cache",
),
}
}
fn as_any(&self) -> &(dyn Any + Send + Sync) {
self
}
}
pub struct MemMissHandler {
body: Arc<RwLock<Vec<u8>>>,
bytes_written: Arc<watch::Sender<PartialState>>,
    // these are used only in finish() to move data from temp to cache
key: String,
cache: Arc<RwLock<HashMap<String, CacheObject>>>,
temp: Arc<RwLock<HashMap<String, TempObject>>>,
}
#[async_trait]
impl HandleMiss for MemMissHandler {
async fn write_body(&mut self, data: bytes::Bytes, eof: bool) -> Result<()> {
let current_bytes = match *self.bytes_written.borrow() {
PartialState::Partial(p) => p,
PartialState::Complete(_) => panic!("already EOF"),
};
self.body.write().extend_from_slice(&data);
let written = current_bytes + data.len();
let new_state = if eof {
PartialState::Complete(written)
} else {
PartialState::Partial(written)
};
self.bytes_written.send_replace(new_state);
Ok(())
}
async fn finish(self: Box<Self>) -> Result<usize> {
// safe, the temp object is inserted when the miss handler is created
let cache_object = self.temp.read().get(&self.key).unwrap().make_cache_object();
        let size = cache_object.body.len(); // FIXME: this is just the body size; also track the meta size
self.cache.write().insert(self.key.clone(), cache_object);
self.temp.write().remove(&self.key);
Ok(size)
}
}
impl Drop for MemMissHandler {
fn drop(&mut self) {
self.temp.write().remove(&self.key);
}
}
#[async_trait]
impl Storage for MemCache {
async fn lookup(
&'static self,
key: &CacheKey,
_trace: &SpanHandle,
) -> Result<Option<(CacheMeta, HitHandler)>> {
let hash = key.combined();
        // always prefer the partial read, otherwise a fresh asset will not be visible over the
        // expired asset until it is fully updated
if let Some(temp_obj) = self.temp.read().get(&hash) {
let meta = CacheMeta::deserialize(&temp_obj.meta.0, &temp_obj.meta.1)?;
let partial = PartialHit {
body: temp_obj.body.clone(),
bytes_written: temp_obj.bytes_written.subscribe(),
bytes_read: 0,
};
let hit_handler = MemHitHandler::Partial(partial);
Ok(Some((meta, Box::new(hit_handler))))
} else if let Some(obj) = self.cached.read().get(&hash) {
let meta = CacheMeta::deserialize(&obj.meta.0, &obj.meta.1)?;
let hit_handler = CompleteHit {
body: obj.body.clone(),
done: false,
range_start: 0,
range_end: obj.body.len(),
};
let hit_handler = MemHitHandler::Complete(hit_handler);
Ok(Some((meta, Box::new(hit_handler))))
} else {
Ok(None)
}
}
async fn get_miss_handler(
&'static self,
key: &CacheKey,
meta: &CacheMeta,
_trace: &SpanHandle,
) -> Result<MissHandler> {
        // TODO: support multiple concurrent writes, or panic if there is already a writer
let hash = key.combined();
let meta = meta.serialize()?;
let temp_obj = TempObject::new(meta);
let miss_handler = MemMissHandler {
body: temp_obj.body.clone(),
bytes_written: temp_obj.bytes_written.clone(),
key: hash.clone(),
cache: self.cached.clone(),
temp: self.temp.clone(),
};
self.temp.write().insert(hash, temp_obj);
Ok(Box::new(miss_handler))
}
async fn purge(&'static self, key: &CompactCacheKey, _trace: &SpanHandle) -> Result<bool> {
// TODO: purge partial
// This usually purges the primary key because, without a lookup, variance key is usually
// empty
let hash = key.combined();
Ok(self.cached.write().remove(&hash).is_some())
}
async fn update_meta(
&'static self,
key: &CacheKey,
meta: &CacheMeta,
_trace: &SpanHandle,
) -> Result<bool> {
let hash = key.combined();
if let Some(obj) = self.cached.write().get_mut(&hash) {
obj.meta = meta.serialize()?;
Ok(true)
} else {
panic!("no meta found")
}
}
fn support_streaming_partial_write(&self) -> bool {
true
}
fn as_any(&self) -> &(dyn Any + Send + Sync) {
self
}
}
#[cfg(test)]
mod test {
use super::*;
use once_cell::sync::Lazy;
use rustracing::span::Span;
fn gen_meta() -> CacheMeta {
let mut header = ResponseHeader::build(200, None).unwrap();
header.append_header("foo1", "bar1").unwrap();
header.append_header("foo2", "bar2").unwrap();
header.append_header("foo3", "bar3").unwrap();
header.append_header("Server", "Pingora").unwrap();
let internal = crate::meta::InternalMeta::default();
CacheMeta(Box::new(crate::meta::CacheMetaInner {
internal,
header,
extensions: http::Extensions::new(),
}))
}
#[tokio::test]
async fn test_write_then_read() {
static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
let span = &Span::inactive().handle();
let key1 = CacheKey::new("", "a", "1");
let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
assert!(res.is_none());
let cache_meta = gen_meta();
let mut miss_handler = MEM_CACHE
.get_miss_handler(&key1, &cache_meta, span)
.await
.unwrap();
miss_handler
.write_body(b"test1"[..].into(), false)
.await
.unwrap();
miss_handler
.write_body(b"test2"[..].into(), false)
.await
.unwrap();
miss_handler.finish().await.unwrap();
let (cache_meta2, mut hit_handler) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
assert_eq!(
cache_meta.0.internal.fresh_until,
cache_meta2.0.internal.fresh_until
);
let data = hit_handler.read_body().await.unwrap().unwrap();
assert_eq!("test1test2", data);
let data = hit_handler.read_body().await.unwrap();
assert!(data.is_none());
}
#[tokio::test]
async fn test_read_range() {
static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
let span = &Span::inactive().handle();
let key1 = CacheKey::new("", "a", "1");
let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
assert!(res.is_none());
let cache_meta = gen_meta();
let mut miss_handler = MEM_CACHE
.get_miss_handler(&key1, &cache_meta, span)
.await
.unwrap();
miss_handler
.write_body(b"test1test2"[..].into(), false)
.await
.unwrap();
miss_handler.finish().await.unwrap();
let (cache_meta2, mut hit_handler) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
assert_eq!(
cache_meta.0.internal.fresh_until,
cache_meta2.0.internal.fresh_until
);
// out of range
assert!(hit_handler.seek(10000, None).is_err());
assert!(hit_handler.seek(5, None).is_ok());
let data = hit_handler.read_body().await.unwrap().unwrap();
assert_eq!("test2", data);
let data = hit_handler.read_body().await.unwrap();
assert!(data.is_none());
assert!(hit_handler.seek(4, Some(5)).is_ok());
let data = hit_handler.read_body().await.unwrap().unwrap();
assert_eq!("1", data);
let data = hit_handler.read_body().await.unwrap();
assert!(data.is_none());
}
#[tokio::test]
async fn test_write_while_read() {
use futures::FutureExt;
static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
let span = &Span::inactive().handle();
let key1 = CacheKey::new("", "a", "1");
let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
assert!(res.is_none());
let cache_meta = gen_meta();
let mut miss_handler = MEM_CACHE
.get_miss_handler(&key1, &cache_meta, span)
.await
.unwrap();
// first reader
let (cache_meta1, mut hit_handler1) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
assert_eq!(
cache_meta.0.internal.fresh_until,
cache_meta1.0.internal.fresh_until
);
// No body to read
let res = hit_handler1.read_body().now_or_never();
assert!(res.is_none());
miss_handler
.write_body(b"test1"[..].into(), false)
.await
.unwrap();
let data = hit_handler1.read_body().await.unwrap().unwrap();
assert_eq!("test1", data);
let res = hit_handler1.read_body().now_or_never();
assert!(res.is_none());
miss_handler
.write_body(b"test2"[..].into(), false)
.await
.unwrap();
let data = hit_handler1.read_body().await.unwrap().unwrap();
assert_eq!("test2", data);
// second reader
let (cache_meta2, mut hit_handler2) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
assert_eq!(
cache_meta.0.internal.fresh_until,
cache_meta2.0.internal.fresh_until
);
let data = hit_handler2.read_body().await.unwrap().unwrap();
assert_eq!("test1test2", data);
let res = hit_handler2.read_body().now_or_never();
assert!(res.is_none());
let res = hit_handler1.read_body().now_or_never();
assert!(res.is_none());
miss_handler.finish().await.unwrap();
let data = hit_handler1.read_body().await.unwrap();
assert!(data.is_none());
let data = hit_handler2.read_body().await.unwrap();
assert!(data.is_none());
}
}
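The partial-read handshake above works off a published watermark: the writer sends `Partial(n)` or `Complete(n)` through the watch channel, and `PartialHit` drains bytes up to that mark, treating `Complete` with nothing left as EOF. A hedged, single-threaded std-only sketch of that state machine, minus the tokio waiting (`read_step` is a hypothetical helper, not part of this crate):

```rust
#[derive(Copy, Clone)]
enum PartialState {
    Partial(usize),
    Complete(usize),
}

// Drain any bytes between `bytes_read` and the published watermark.
// Returns (new_bytes, eof), mirroring one iteration of PartialHit::read.
fn read_step(
    body: &[u8],
    state: PartialState,
    bytes_read: &mut usize,
) -> (Option<Vec<u8>>, bool) {
    let (end, complete) = match state {
        PartialState::Partial(n) => (n, false),
        PartialState::Complete(n) => (n, true),
    };
    if end > *bytes_read {
        let chunk = body[*bytes_read..end].to_vec();
        *bytes_read = end;
        (Some(chunk), false)
    } else {
        // nothing new to read; EOF only once the writer declared Complete
        (None, complete)
    }
}

fn main() {
    let body = b"test1test2";
    let mut read = 0;
    let (chunk, eof) = read_step(body, PartialState::Partial(5), &mut read);
    assert_eq!(chunk.as_deref(), Some(&b"test1"[..]));
    assert!(!eof);
    let (chunk, eof) = read_step(body, PartialState::Complete(10), &mut read);
    assert_eq!(chunk.as_deref(), Some(&b"test2"[..]));
    assert!(!eof);
    let (chunk, eof) = read_step(body, PartialState::Complete(10), &mut read);
    assert!(chunk.is_none() && eof);
}
```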

608
pingora-cache/src/meta.rs Normal file

@ -0,0 +1,608 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Metadata for caching
use http::Extensions;
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use pingora_http::{HMap, ResponseHeader};
use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime};
use crate::key::HashBinary;
pub(crate) type InternalMeta = internal_meta::InternalMetaLatest;
mod internal_meta {
use super::*;
pub(crate) type InternalMetaLatest = InternalMetaV2;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub(crate) struct InternalMetaV0 {
pub(crate) fresh_until: SystemTime,
pub(crate) created: SystemTime,
pub(crate) stale_while_revalidate_sec: u32,
pub(crate) stale_if_error_sec: u32,
        // Do not add more fields
}
impl InternalMetaV0 {
#[allow(dead_code)]
fn serialize(&self) -> Result<Vec<u8>> {
rmp_serde::encode::to_vec(self).or_err(InternalError, "failed to encode cache meta")
}
fn deserialize(buf: &[u8]) -> Result<Self> {
rmp_serde::decode::from_slice(buf)
.or_err(InternalError, "failed to decode cache meta v0")
}
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub(crate) struct InternalMetaV1 {
pub(crate) version: u8,
pub(crate) fresh_until: SystemTime,
pub(crate) created: SystemTime,
pub(crate) stale_while_revalidate_sec: u32,
pub(crate) stale_if_error_sec: u32,
        // Do not add more fields
}
impl InternalMetaV1 {
#[allow(dead_code)]
pub const VERSION: u8 = 1;
#[allow(dead_code)]
pub fn serialize(&self) -> Result<Vec<u8>> {
assert_eq!(self.version, 1);
rmp_serde::encode::to_vec(self).or_err(InternalError, "failed to encode cache meta")
}
fn deserialize(buf: &[u8]) -> Result<Self> {
rmp_serde::decode::from_slice(buf)
.or_err(InternalError, "failed to decode cache meta v1")
}
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub(crate) struct InternalMetaV2 {
pub(crate) version: u8,
pub(crate) fresh_until: SystemTime,
pub(crate) created: SystemTime,
pub(crate) updated: SystemTime,
pub(crate) stale_while_revalidate_sec: u32,
pub(crate) stale_if_error_sec: u32,
        // Only extended fields are to be added below, one field at a time.
// 1. serde default in order to accept an older version schema without the field existing
// 2. serde skip_serializing_if in order for software with only an older version of this
// schema to decode it
// After full releases, remove `skip_serializing_if` so that we can add the next extended field.
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) variance: Option<HashBinary>,
}
impl Default for InternalMetaV2 {
fn default() -> Self {
let epoch = SystemTime::UNIX_EPOCH;
InternalMetaV2 {
version: InternalMetaV2::VERSION,
fresh_until: epoch,
created: epoch,
updated: epoch,
stale_while_revalidate_sec: 0,
stale_if_error_sec: 0,
variance: None,
}
}
}
impl InternalMetaV2 {
pub const VERSION: u8 = 2;
pub fn serialize(&self) -> Result<Vec<u8>> {
assert_eq!(self.version, Self::VERSION);
rmp_serde::encode::to_vec(self).or_err(InternalError, "failed to encode cache meta")
}
fn deserialize(buf: &[u8]) -> Result<Self> {
rmp_serde::decode::from_slice(buf)
.or_err(InternalError, "failed to decode cache meta v2")
}
}
impl From<InternalMetaV0> for InternalMetaV2 {
fn from(v0: InternalMetaV0) -> Self {
InternalMetaV2 {
version: InternalMetaV2::VERSION,
fresh_until: v0.fresh_until,
created: v0.created,
updated: v0.created,
stale_while_revalidate_sec: v0.stale_while_revalidate_sec,
stale_if_error_sec: v0.stale_if_error_sec,
..Default::default()
}
}
}
impl From<InternalMetaV1> for InternalMetaV2 {
fn from(v1: InternalMetaV1) -> Self {
InternalMetaV2 {
version: InternalMetaV2::VERSION,
fresh_until: v1.fresh_until,
created: v1.created,
updated: v1.created,
stale_while_revalidate_sec: v1.stale_while_revalidate_sec,
stale_if_error_sec: v1.stale_if_error_sec,
..Default::default()
}
}
}
// cross version decode
pub(crate) fn deserialize(buf: &[u8]) -> Result<InternalMetaLatest> {
const MIN_SIZE: usize = 10; // a small number to read the first few bytes
if buf.len() < MIN_SIZE {
return Error::e_explain(
InternalError,
format!("Buf too short ({}) to be InternalMeta", buf.len()),
);
}
let preread_buf = &mut &buf[..MIN_SIZE];
// the struct is always packed as a fixed size array
match rmp::decode::read_array_len(preread_buf)
.or_err(InternalError, "failed to decode cache meta array size")?
{
// v0 has 4 items and no version number
4 => Ok(InternalMetaV0::deserialize(buf)?.into()),
// other versions should have the version number encoded
_ => {
// rmp will encode version < 128 into a fixint (one byte),
// so we use read_pfix
let version = rmp::decode::read_pfix(preread_buf)
.or_err(InternalError, "failed to decode meta version")?;
match version {
1 => Ok(InternalMetaV1::deserialize(buf)?.into()),
2 => InternalMetaV2::deserialize(buf),
_ => Error::e_explain(
InternalError,
format!("Unknown InternalMeta version {version}"),
),
}
}
}
}
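The dispatch above leans on a MessagePack detail: rmp_serde packs a struct as a fixarray whose header byte is `0x90 + field_count` (for up to 15 fields), so peeking at the first byte or two is enough to tell schema versions apart. A dependency-free sketch of that trick (the helper names are illustrative, not part of this crate):

```rust
// Decode a MessagePack fixarray header byte (0x90..=0x9f) into its length.
fn fixarray_len(first_byte: u8) -> Option<u8> {
    if (0x90..=0x9f).contains(&first_byte) {
        Some(first_byte & 0x0f)
    } else {
        None
    }
}

// Guess the schema version from the first bytes of a serialized meta.
fn guess_version(buf: &[u8]) -> Option<u8> {
    match fixarray_len(*buf.first()?)? {
        // v0 packed 4 fields and carried no version number
        4 => Some(0),
        // later versions store the version as the first array element,
        // encoded as a positive fixint (a single byte < 0x80)
        _ => buf.get(1).copied().filter(|b| *b < 0x80),
    }
}

fn main() {
    // a 4-element fixarray is treated as v0
    assert_eq!(guess_version(&[0x94, 0x01]), Some(0));
    // a 7-element fixarray whose first element is fixint 2 is v2
    assert_eq!(guess_version(&[0x97, 0x02]), Some(2));
    // not a fixarray at all
    assert_eq!(guess_version(&[0x00]), None);
}
```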
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_internal_meta_serde_v0() {
let meta = InternalMetaV0 {
fresh_until: SystemTime::now(),
created: SystemTime::now(),
stale_while_revalidate_sec: 0,
stale_if_error_sec: 0,
};
let binary = meta.serialize().unwrap();
let meta2 = InternalMetaV0::deserialize(&binary).unwrap();
assert_eq!(meta.fresh_until, meta2.fresh_until);
}
#[test]
fn test_internal_meta_serde_v1() {
let meta = InternalMetaV1 {
version: InternalMetaV1::VERSION,
fresh_until: SystemTime::now(),
created: SystemTime::now(),
stale_while_revalidate_sec: 0,
stale_if_error_sec: 0,
};
let binary = meta.serialize().unwrap();
let meta2 = InternalMetaV1::deserialize(&binary).unwrap();
assert_eq!(meta.fresh_until, meta2.fresh_until);
}
#[test]
fn test_internal_meta_serde_v2() {
let meta = InternalMetaV2::default();
let binary = meta.serialize().unwrap();
let meta2 = InternalMetaV2::deserialize(&binary).unwrap();
assert_eq!(meta2.version, 2);
assert_eq!(meta.fresh_until, meta2.fresh_until);
assert_eq!(meta.created, meta2.created);
assert_eq!(meta.updated, meta2.updated);
}
#[test]
fn test_internal_meta_serde_across_versions() {
let meta = InternalMetaV0 {
fresh_until: SystemTime::now(),
created: SystemTime::now(),
stale_while_revalidate_sec: 0,
stale_if_error_sec: 0,
};
let binary = meta.serialize().unwrap();
let meta2 = deserialize(&binary).unwrap();
assert_eq!(meta2.version, 2);
assert_eq!(meta.fresh_until, meta2.fresh_until);
let meta = InternalMetaV1 {
version: 1,
fresh_until: SystemTime::now(),
created: SystemTime::now(),
stale_while_revalidate_sec: 0,
stale_if_error_sec: 0,
};
let binary = meta.serialize().unwrap();
let meta2 = deserialize(&binary).unwrap();
assert_eq!(meta2.version, 2);
assert_eq!(meta.fresh_until, meta2.fresh_until);
// `updated` == `created` when upgrading to v2
assert_eq!(meta2.created, meta2.updated);
}
#[test]
fn test_internal_meta_serde_v2_extend_fields() {
// make sure that v2 format is backward compatible
// this is the base version of v2 without any extended fields
#[derive(Deserialize, Serialize)]
pub(crate) struct InternalMetaV2Base {
pub(crate) version: u8,
pub(crate) fresh_until: SystemTime,
pub(crate) created: SystemTime,
pub(crate) updated: SystemTime,
pub(crate) stale_while_revalidate_sec: u32,
pub(crate) stale_if_error_sec: u32,
}
impl InternalMetaV2Base {
pub const VERSION: u8 = 2;
pub fn serialize(&self) -> Result<Vec<u8>> {
assert!(self.version >= Self::VERSION);
rmp_serde::encode::to_vec(self)
.or_err(InternalError, "failed to encode cache meta")
}
fn deserialize(buf: &[u8]) -> Result<Self> {
rmp_serde::decode::from_slice(buf)
.or_err(InternalError, "failed to decode cache meta v2")
}
}
// ext V2 to base v2
let meta = InternalMetaV2::default();
let binary = meta.serialize().unwrap();
let meta2 = InternalMetaV2Base::deserialize(&binary).unwrap();
assert_eq!(meta2.version, 2);
assert_eq!(meta.fresh_until, meta2.fresh_until);
assert_eq!(meta.created, meta2.created);
assert_eq!(meta.updated, meta2.updated);
// base V2 to ext v2
let now = SystemTime::now();
let meta = InternalMetaV2Base {
version: InternalMetaV2::VERSION,
fresh_until: now,
created: now,
updated: now,
stale_while_revalidate_sec: 0,
stale_if_error_sec: 0,
};
let binary = meta.serialize().unwrap();
let meta2 = InternalMetaV2::deserialize(&binary).unwrap();
assert_eq!(meta2.version, 2);
assert_eq!(meta.fresh_until, meta2.fresh_until);
assert_eq!(meta.created, meta2.created);
assert_eq!(meta.updated, meta2.updated);
}
}
}
#[derive(Debug)]
pub(crate) struct CacheMetaInner {
// http header and Internal meta have different ways of serialization, so keep them separated
pub(crate) internal: InternalMeta,
pub(crate) header: ResponseHeader,
/// An opaque type map to hold extra information for communication between cache backends
/// and users. This field is **not** guaranteed to be persistently stored in the cache backend.
pub extensions: Extensions,
}
/// The cacheable response header and cache metadata
#[derive(Debug)]
pub struct CacheMeta(pub(crate) Box<CacheMetaInner>);
impl CacheMeta {
/// Create a [CacheMeta] from the given metadata and the response header
pub fn new(
fresh_until: SystemTime,
created: SystemTime,
stale_while_revalidate_sec: u32,
stale_if_error_sec: u32,
header: ResponseHeader,
) -> CacheMeta {
CacheMeta(Box::new(CacheMetaInner {
internal: InternalMeta {
version: InternalMeta::VERSION,
fresh_until,
created,
updated: created, // created == updated for new meta
stale_while_revalidate_sec,
stale_if_error_sec,
..Default::default()
},
header,
extensions: Extensions::new(),
}))
}
/// When the asset was created/admitted to cache
pub fn created(&self) -> SystemTime {
self.0.internal.created
}
/// The last time the asset was revalidated
///
/// This value will be the same as [Self::created()] if no revalidation ever happens
pub fn updated(&self) -> SystemTime {
self.0.internal.updated
}
/// Is the asset still valid
pub fn is_fresh(&self, time: SystemTime) -> bool {
// NOTE: HTTP cache time resolution is second
self.0.internal.fresh_until >= time
}
/// How long (in seconds) the asset should be fresh since its admission/revalidation
///
/// This is essentially the max-age value (or its equivalence)
pub fn fresh_sec(&self) -> u64 {
// swallow the `duration_since` error; assets that are always stale have an earlier `fresh_until` than `created`
// practically speaking we can always treat these as 0 ttl
// XXX: return Error if `fresh_until` is much earlier than expected?
self.0
.internal
.fresh_until
.duration_since(self.0.internal.updated)
.map_or(0, |duration| duration.as_secs())
}
/// Until when the asset is considered fresh
pub fn fresh_until(&self) -> SystemTime {
self.0.internal.fresh_until
}
/// How old the asset is since its admission/revalidation
pub fn age(&self) -> Duration {
SystemTime::now()
.duration_since(self.updated())
.unwrap_or_default()
}
/// The stale-while-revalidate limit in seconds
pub fn stale_while_revalidate_sec(&self) -> u32 {
self.0.internal.stale_while_revalidate_sec
}
/// The stale-if-error limit in seconds
pub fn stale_if_error_sec(&self) -> u32 {
self.0.internal.stale_if_error_sec
}
/// Can the asset be used to serve stale during revalidation at the given time.
///
/// NOTE: the serve stale functions do not check !is_fresh(time),
/// i.e. the object is already assumed to be stale.
pub fn serve_stale_while_revalidate(&self, time: SystemTime) -> bool {
self.can_serve_stale(self.0.internal.stale_while_revalidate_sec, time)
}
/// Can the asset be used to serve stale after error at the given time.
///
/// NOTE: the serve stale functions do not check !is_fresh(time),
/// i.e. the object is already assumed to be stale.
pub fn serve_stale_if_error(&self, time: SystemTime) -> bool {
self.can_serve_stale(self.0.internal.stale_if_error_sec, time)
}
/// Disable serve stale for this asset
pub fn disable_serve_stale(&mut self) {
self.0.internal.stale_if_error_sec = 0;
self.0.internal.stale_while_revalidate_sec = 0;
}
/// Get the variance hash of this asset
pub fn variance(&self) -> Option<HashBinary> {
self.0.internal.variance
}
/// Set the variance key of this asset
pub fn set_variance_key(&mut self, variance_key: HashBinary) {
self.0.internal.variance = Some(variance_key);
}
/// Set the variance (hash) of this asset
pub fn set_variance(&mut self, variance: HashBinary) {
self.0.internal.variance = Some(variance)
}
/// Removes the variance (hash) of this asset
pub fn remove_variance(&mut self) {
self.0.internal.variance = None
}
/// Get the response header in this asset
pub fn response_header(&self) -> &ResponseHeader {
&self.0.header
}
/// Modify the header in this asset
pub fn response_header_mut(&mut self) -> &mut ResponseHeader {
&mut self.0.header
}
/// Expose the extensions to read
pub fn extensions(&self) -> &Extensions {
&self.0.extensions
}
/// Expose the extensions to modify
pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.0.extensions
}
/// Get a copy of the response header
pub fn response_header_copy(&self) -> ResponseHeader {
self.0.header.clone()
}
/// get all the headers of this asset
pub fn headers(&self) -> &HMap {
&self.0.header.headers
}
fn can_serve_stale(&self, serve_stale_sec: u32, time: SystemTime) -> bool {
if serve_stale_sec == 0 {
return false;
}
if let Some(stale_until) = self
.0
.internal
.fresh_until
.checked_add(Duration::from_secs(serve_stale_sec.into()))
{
stale_until >= time
} else {
// overflowed: treat as infinite ttl
true
}
}
/// Serialize this object
pub fn serialize(&self) -> Result<(Vec<u8>, Vec<u8>)> {
let internal = self.0.internal.serialize()?;
let header = header_serialize(&self.0.header)?;
Ok((internal, header))
}
/// Deserialize from the binary format
pub fn deserialize(internal: &[u8], header: &[u8]) -> Result<Self> {
let internal = internal_meta::deserialize(internal)?;
let header = header_deserialize(header)?;
Ok(CacheMeta(Box::new(CacheMetaInner {
internal,
header,
extensions: Extensions::new(),
})))
}
}
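The freshness and serve-stale checks above reduce to simple clock arithmetic: an asset is fresh until `fresh_until`, and may still serve stale for `serve_stale_sec` beyond that, with overflow treated as an infinite TTL. A minimal std-only sketch of that logic (free functions standing in for the `CacheMeta` methods):

```rust
use std::time::{Duration, SystemTime};

fn is_fresh(fresh_until: SystemTime, now: SystemTime) -> bool {
    // NOTE: HTTP cache time resolution is one second
    fresh_until >= now
}

fn can_serve_stale(fresh_until: SystemTime, serve_stale_sec: u32, now: SystemTime) -> bool {
    if serve_stale_sec == 0 {
        return false;
    }
    match fresh_until.checked_add(Duration::from_secs(serve_stale_sec.into())) {
        Some(stale_until) => stale_until >= now,
        None => true, // overflowed: treat as an infinite TTL
    }
}

fn main() {
    let now = SystemTime::now();
    // fresh for another 10 seconds
    assert!(is_fresh(now + Duration::from_secs(10), now));
    // went stale 5 seconds ago, but a 30s serve-stale window still covers it
    let stale = now - Duration::from_secs(5);
    assert!(!is_fresh(stale, now));
    assert!(can_serve_stale(stale, 30, now));
    assert!(!can_serve_stale(stale, 0, now));
}
```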
use http::StatusCode;
/// The function to generate TTL from the given [StatusCode].
pub type FreshSecByStatusFn = fn(StatusCode) -> Option<u32>;
/// The default settings to generate [CacheMeta]
pub struct CacheMetaDefaults {
// if a status code is not included in fresh_sec, it's not considered cacheable by default.
fresh_sec_fn: FreshSecByStatusFn,
stale_while_revalidate_sec: u32,
// TODO: allow "error" condition to be configurable?
stale_if_error_sec: u32,
}
impl CacheMetaDefaults {
/// Create a new [CacheMetaDefaults]
pub const fn new(
fresh_sec_fn: FreshSecByStatusFn,
stale_while_revalidate_sec: u32,
stale_if_error_sec: u32,
) -> Self {
CacheMetaDefaults {
fresh_sec_fn,
stale_while_revalidate_sec,
stale_if_error_sec,
}
}
/// Return the default TTL for the given [StatusCode]
///
/// `None`: do not cache this status code.
pub fn fresh_sec(&self, resp_status: StatusCode) -> Option<u32> {
// safeguard to make sure a 304 response shares the same default TTL as a 200
if resp_status == StatusCode::NOT_MODIFIED {
(self.fresh_sec_fn)(StatusCode::OK)
} else {
(self.fresh_sec_fn)(resp_status)
}
}
/// The default SWR seconds
pub fn serve_stale_while_revalidate_sec(&self) -> u32 {
self.stale_while_revalidate_sec
}
/// The default SIE seconds
pub fn serve_stale_if_error_sec(&self) -> u32 {
self.stale_if_error_sec
}
}
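A `FreshSecByStatusFn` is just a plain `fn` keyed on the status code, so the defaults table is typically a `match`. A hedged sketch of such a table, with the same 304-shares-200 safeguard as above (plain `u16` stands in for `http::StatusCode` to stay dependency-free; the TTL values are illustrative, not Pingora defaults):

```rust
// Hypothetical TTL table: which status codes are cacheable, and for how long.
fn fresh_sec_by_status(status: u16) -> Option<u32> {
    match status {
        200 => Some(3600),  // cache successful responses for an hour
        301 => Some(86400), // permanent redirects can be cached longer
        404 => Some(60),    // negative-cache misses briefly
        _ => None,          // everything else: not cacheable by default
    }
}

// Mirror of the safeguard above: a 304 shares the default TTL of a 200.
fn default_fresh_sec(status: u16) -> Option<u32> {
    if status == 304 {
        fresh_sec_by_status(200)
    } else {
        fresh_sec_by_status(status)
    }
}

fn main() {
    assert_eq!(default_fresh_sec(200), Some(3600));
    assert_eq!(default_fresh_sec(304), Some(3600));
    assert_eq!(default_fresh_sec(500), None);
}
```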
use log::warn;
use once_cell::sync::{Lazy, OnceCell};
use pingora_header_serde::HeaderSerde;
use std::fs::File;
use std::io::Read;
/* load the header compression engine and its dictionary globally */
pub(crate) static COMPRESSION_DICT_PATH: OnceCell<String> = OnceCell::new();
fn load_file(path: &String) -> Option<Vec<u8>> {
let mut file = File::open(path)
.map_err(|e| {
warn!(
"failed to open header compress dictionary file at {}, {:?}",
path, e
);
e
})
.ok()?;
let mut dict = Vec::new();
file.read_to_end(&mut dict)
.map_err(|e| {
warn!(
"failed to read header compress dictionary file at {}, {:?}",
path, e
);
e
})
.ok()?;
Some(dict)
}
static HEADER_SERDE: Lazy<HeaderSerde> = Lazy::new(|| {
let dict = COMPRESSION_DICT_PATH.get().and_then(load_file);
HeaderSerde::new(dict)
});
pub(crate) fn header_serialize(header: &ResponseHeader) -> Result<Vec<u8>> {
HEADER_SERDE.serialize(header)
}
pub(crate) fn header_deserialize<T: AsRef<[u8]>>(buf: T) -> Result<ResponseHeader> {
HEADER_SERDE.deserialize(buf.as_ref())
}

// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Cacheability Predictor
use crate::hashtable::{ConcurrentLruCache, LruShard};
pub type CustomReasonPredicate = fn(&'static str) -> bool;
/// Cacheability Predictor
///
/// Remembers previously uncacheable assets.
/// Allows bypassing cache / cache lock early based on historical precedent.
///
/// NOTE: to simply avoid caching requests with certain characteristics,
/// add checks in request_cache_filter to avoid enabling cache in the first place.
/// The predictor's bypass mechanism handles cases where the request _looks_ cacheable
/// but its previous responses suggest otherwise. The request _could_ be cacheable in the future.
pub struct Predictor<const N_SHARDS: usize> {
uncacheable_keys: ConcurrentLruCache<(), N_SHARDS>,
skip_custom_reasons_fn: Option<CustomReasonPredicate>,
}
use crate::{key::CacheHashKey, CacheKey, NoCacheReason};
use log::debug;
/// The cache predictor trait.
///
/// This trait allows user defined predictor to replace [Predictor].
pub trait CacheablePredictor {
/// Return true if likely cacheable, false if likely not.
fn cacheable_prediction(&self, key: &CacheKey) -> bool;
/// Mark cacheable to allow next request to cache.
/// Returns false if the key was already marked cacheable.
fn mark_cacheable(&self, key: &CacheKey) -> bool;
/// Mark uncacheable to actively bypass cache on the next request.
/// May skip marking on certain NoCacheReasons.
/// Returns None if we skipped marking uncacheable.
/// Returns Some(false) if the key was already marked uncacheable.
fn mark_uncacheable(&self, key: &CacheKey, reason: NoCacheReason) -> Option<bool>;
}
// This particular bit of `where [LruShard...; N]: Default` nonsense arises from
// ConcurrentLruCache needing this trait bound, which in turns arises from the Rust
// compiler not being able to guarantee that all array sizes N implement `Default`.
// See https://github.com/rust-lang/rust/issues/61415
impl<const N_SHARDS: usize> Predictor<N_SHARDS>
where
[LruShard<()>; N_SHARDS]: Default,
{
/// Create a new Predictor with `N_SHARDS * shard_capacity` total capacity for
/// uncacheable cache keys.
///
/// - `shard_capacity`: defines number of keys remembered as uncacheable per LRU shard.
/// - `skip_custom_reasons_fn`: an optional predicate used in `mark_uncacheable`
/// that can customize which `Custom` `NoCacheReason`s ought to be remembered as uncacheable.
/// If the predicate returns true, then the predictor will skip remembering the current
/// cache key as uncacheable (and avoid bypassing cache on the next request).
pub fn new(
shard_capacity: usize,
skip_custom_reasons_fn: Option<CustomReasonPredicate>,
) -> Predictor<N_SHARDS> {
Predictor {
uncacheable_keys: ConcurrentLruCache::<(), N_SHARDS>::new(shard_capacity),
skip_custom_reasons_fn,
}
}
}
impl<const N_SHARDS: usize> CacheablePredictor for Predictor<N_SHARDS>
where
[LruShard<()>; N_SHARDS]: Default,
{
fn cacheable_prediction(&self, key: &CacheKey) -> bool {
// variance key is ignored because this check happens before cache lookup
let hash = key.primary_bin();
let key = u128::from_be_bytes(hash); // Endianness doesn't matter
// Note: LRU updated in mark_* functions only,
// as we assume the caller always updates the cacheability of the response later
!self.uncacheable_keys.read(key).contains(&key)
}
fn mark_cacheable(&self, key: &CacheKey) -> bool {
// variance key is ignored because cacheable_prediction() is called before cache lookup
// where the variance key is unknown
let hash = key.primary_bin();
let key = u128::from_be_bytes(hash);
let cache = self.uncacheable_keys.get(key);
if !cache.read().contains(&key) {
// not in uncacheable list, nothing to do
return true;
}
let mut cache = cache.write();
cache.pop(&key);
debug!("bypassed request became cacheable");
false
}
fn mark_uncacheable(&self, key: &CacheKey, reason: NoCacheReason) -> Option<bool> {
// only mark as uncacheable for the future on certain reasons,
// (e.g. InternalErrors)
use NoCacheReason::*;
match reason {
// CacheLockGiveUp: the writer will set OriginNotCache (if applicable)
// readers don't need to do it
NeverEnabled | StorageError | InternalError | Deferred | CacheLockGiveUp
| CacheLockTimeout => {
return None;
}
// Skip certain NoCacheReason::Custom according to user
Custom(reason) if self.skip_custom_reasons_fn.map_or(false, |f| f(reason)) => {
return None;
}
Custom(_) | OriginNotCache | ResponseTooLarge => { /* mark uncacheable for these only */
}
}
// variance key is ignored because cacheable_prediction() is called before cache lookup
// where the variance key is unknown
let hash = key.primary_bin();
let key = u128::from_be_bytes(hash);
let mut cache = self.uncacheable_keys.get(key).write();
// put() returns Some(old_value) if the key existed, else None
let new_key = cache.put(key, ()).is_none();
if new_key {
debug!("request marked uncacheable");
}
Some(new_key)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mark_cacheability() {
let predictor = Predictor::<1>::new(10, None);
let key = CacheKey::new("a", "b", "c");
// cacheable if no history
assert!(predictor.cacheable_prediction(&key));
// don't remember internal / storage errors
predictor.mark_uncacheable(&key, NoCacheReason::InternalError);
assert!(predictor.cacheable_prediction(&key));
predictor.mark_uncacheable(&key, NoCacheReason::StorageError);
assert!(predictor.cacheable_prediction(&key));
// origin explicitly said uncacheable
predictor.mark_uncacheable(&key, NoCacheReason::OriginNotCache);
assert!(!predictor.cacheable_prediction(&key));
// mark cacheable again
predictor.mark_cacheable(&key);
assert!(predictor.cacheable_prediction(&key));
}
#[test]
fn test_custom_skip_predicate() {
let predictor = Predictor::<1>::new(
10,
Some(|custom_reason| matches!(custom_reason, "Skipping")),
);
let key = CacheKey::new("a", "b", "c");
// cacheable if no history
assert!(predictor.cacheable_prediction(&key));
// custom predicate still uses default skip reasons
predictor.mark_uncacheable(&key, NoCacheReason::InternalError);
assert!(predictor.cacheable_prediction(&key));
// other custom reasons can still be marked uncacheable
predictor.mark_uncacheable(&key, NoCacheReason::Custom("DontCacheMe"));
assert!(!predictor.cacheable_prediction(&key));
let key = CacheKey::new("a", "c", "d");
assert!(predictor.cacheable_prediction(&key));
// specific custom reason is skipped
predictor.mark_uncacheable(&key, NoCacheReason::Custom("Skipping"));
assert!(predictor.cacheable_prediction(&key));
}
#[test]
fn test_mark_uncacheable_lru() {
let predictor = Predictor::<1>::new(3, None);
let key1 = CacheKey::new("a", "b", "c");
predictor.mark_uncacheable(&key1, NoCacheReason::OriginNotCache);
assert!(!predictor.cacheable_prediction(&key1));
let key2 = CacheKey::new("a", "bc", "c");
predictor.mark_uncacheable(&key2, NoCacheReason::OriginNotCache);
assert!(!predictor.cacheable_prediction(&key2));
let key3 = CacheKey::new("a", "cd", "c");
predictor.mark_uncacheable(&key3, NoCacheReason::OriginNotCache);
assert!(!predictor.cacheable_prediction(&key3));
// promote / reinsert key1
predictor.mark_uncacheable(&key1, NoCacheReason::OriginNotCache);
let key4 = CacheKey::new("a", "de", "c");
predictor.mark_uncacheable(&key4, NoCacheReason::OriginNotCache);
assert!(!predictor.cacheable_prediction(&key4));
// key 1 was recently used
assert!(!predictor.cacheable_prediction(&key1));
// key 2 was evicted
assert!(predictor.cacheable_prediction(&key2));
assert!(!predictor.cacheable_prediction(&key3));
assert!(!predictor.cacheable_prediction(&key4));
}
}
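The predictor is essentially a capacity-bounded negative cache: remembered-uncacheable key hashes evict in LRU order, get promoted on re-marking, and are removed once a response proves cacheable again. A std-only, single-shard sketch of that idea (a `VecDeque` stands in for the real sharded concurrent LRU; names are illustrative):

```rust
use std::collections::VecDeque;

// Tiny LRU set over u128 key hashes; most recently used at the back.
struct UncacheableLru {
    cap: usize,
    keys: VecDeque<u128>,
}

impl UncacheableLru {
    fn new(cap: usize) -> Self {
        UncacheableLru { cap, keys: VecDeque::new() }
    }
    // Cacheable unless the key was previously remembered as uncacheable.
    fn cacheable_prediction(&self, key: u128) -> bool {
        !self.keys.contains(&key)
    }
    fn mark_uncacheable(&mut self, key: u128) {
        if let Some(pos) = self.keys.iter().position(|k| *k == key) {
            let _ = self.keys.remove(pos); // promote on reuse
        } else if self.keys.len() == self.cap {
            self.keys.pop_front(); // evict the least recently used key
        }
        self.keys.push_back(key);
    }
    fn mark_cacheable(&mut self, key: u128) {
        if let Some(pos) = self.keys.iter().position(|k| *k == key) {
            let _ = self.keys.remove(pos);
        }
    }
}

fn main() {
    let mut lru = UncacheableLru::new(2);
    lru.mark_uncacheable(1);
    lru.mark_uncacheable(2);
    assert!(!lru.cacheable_prediction(1));
    lru.mark_uncacheable(3); // capacity 2: key 1 is evicted
    assert!(lru.cacheable_prediction(1));
    lru.mark_cacheable(3); // a cacheable response clears the record
    assert!(lru.cacheable_prediction(3));
}
```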

pingora-cache/src/put.rs
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Cache Put module
use crate::*;
use bytes::Bytes;
use http::header;
use pingora_core::protocols::http::{
v1::common::header_value_content_length, HttpTask, ServerSession,
};
/// The interface to define cache put behavior
pub trait CachePut {
/// Return whether to cache the asset according to the given response header.
fn cacheable(&self, response: &ResponseHeader) -> RespCacheable {
let cc = cache_control::CacheControl::from_resp_headers(response);
filters::resp_cacheable(cc.as_ref(), response, false, Self::cache_defaults())
}
/// Return the [CacheMetaDefaults]
fn cache_defaults() -> &'static CacheMetaDefaults;
}
use parse_response::ResponseParse;
/// The cache put context
pub struct CachePutCtx<C: CachePut> {
cache_put: C, // the user defined cache put behavior
key: CacheKey,
storage: &'static (dyn storage::Storage + Sync), // static for now
eviction: Option<&'static (dyn eviction::EvictionManager + Sync)>,
miss_handler: Option<MissHandler>,
max_file_size_bytes: Option<usize>,
meta: Option<CacheMeta>,
parser: ResponseParse,
// FIXME: cache put doesn't have a cache lock, but some storage backends cannot handle
// concurrent puts to the same asset.
trace: trace::Span,
}
impl<C: CachePut> CachePutCtx<C> {
/// Create a new [CachePutCtx]
pub fn new(
cache_put: C,
key: CacheKey,
storage: &'static (dyn storage::Storage + Sync),
eviction: Option<&'static (dyn eviction::EvictionManager + Sync)>,
trace: trace::Span,
) -> Self {
CachePutCtx {
cache_put,
key,
storage,
eviction,
miss_handler: None,
max_file_size_bytes: None,
meta: None,
parser: ResponseParse::new(),
trace,
}
}
/// Set the max cacheable size limit
pub fn set_max_file_size_bytes(&mut self, max_file_size_bytes: usize) {
self.max_file_size_bytes = Some(max_file_size_bytes);
}
async fn put_header(&mut self, meta: CacheMeta) -> Result<()> {
let trace = self.trace.child("cache put header", |o| o.start()).handle();
let miss_handler = self
.storage
.get_miss_handler(&self.key, &meta, &trace)
.await?;
self.miss_handler = Some(
if let Some(max_file_size_bytes) = self.max_file_size_bytes {
Box::new(MaxFileSizeMissHandler::new(
miss_handler,
max_file_size_bytes,
))
} else {
miss_handler
},
);
self.meta = Some(meta);
Ok(())
}
async fn put_body(&mut self, data: Bytes, eof: bool) -> Result<()> {
let miss_handler = self.miss_handler.as_mut().unwrap();
miss_handler.write_body(data, eof).await
}
async fn finish(&mut self) -> Result<()> {
let Some(miss_handler) = self.miss_handler.take() else {
// no miss_handler, uncacheable
return Ok(());
};
let size = miss_handler.finish().await?;
if let Some(eviction) = self.eviction.as_ref() {
let cache_key = self.key.to_compact();
let meta = self.meta.as_ref().unwrap();
let evicted = eviction.admit(cache_key, size, meta.0.internal.fresh_until);
// TODO: make this async
let trace = self
.trace
.child("cache put eviction", |o| o.start())
.handle();
for item in evicted {
// TODO: warn/log the error
let _ = self.storage.purge(&item, &trace).await;
}
}
Ok(())
}
async fn do_cache_put(&mut self, data: &[u8]) -> Result<Option<NoCacheReason>> {
let tasks = self.parser.inject_data(data)?;
for task in tasks {
match task {
HttpTask::Header(header, _eos) => match self.cache_put.cacheable(&header) {
RespCacheable::Cacheable(meta) => {
if let Some(max_file_size_bytes) = self.max_file_size_bytes {
let content_length_hdr = header.headers.get(header::CONTENT_LENGTH);
if let Some(content_length) =
header_value_content_length(content_length_hdr)
{
if content_length > max_file_size_bytes {
return Ok(Some(NoCacheReason::ResponseTooLarge));
}
}
}
self.put_header(meta).await?;
}
RespCacheable::Uncacheable(reason) => {
return Ok(Some(reason));
}
},
HttpTask::Body(data, eos) => {
if let Some(data) = data {
self.put_body(data, eos).await?;
}
}
_ => {
panic!("unexpected HttpTask during cache put {task:?}");
}
}
}
Ok(None)
}
/// Start the cache put logic for the given request
///
/// This function will start to read the request body to put into cache.
/// Return:
/// - `Ok(None)` when the payload will be cached.
/// - `Ok(Some(reason))` when the payload is not cacheable.
pub async fn cache_put(
&mut self,
session: &mut ServerSession,
) -> Result<Option<NoCacheReason>> {
let mut no_cache_reason = None;
while let Some(data) = session.read_request_body().await? {
if no_cache_reason.is_some() {
// even if uncacheable, the entire body needs to be drained for 1. the downstream
// not to throw errors 2. connection reuse
continue;
}
no_cache_reason = self.do_cache_put(&data).await?
}
self.parser.finish()?;
self.finish().await?;
Ok(no_cache_reason)
}
}
#[cfg(test)]
mod test {
use super::*;
use once_cell::sync::Lazy;
use rustracing::span::Span;
struct TestCachePut();
impl CachePut for TestCachePut {
fn cache_defaults() -> &'static CacheMetaDefaults {
const DEFAULT: CacheMetaDefaults = CacheMetaDefaults::new(|_| Some(1), 1, 1);
&DEFAULT
}
}
type TestCachePutCtx = CachePutCtx<TestCachePut>;
static CACHE_BACKEND: Lazy<MemCache> = Lazy::new(MemCache::new);
#[tokio::test]
async fn test_cache_put() {
let key = CacheKey::new("", "a", "1");
let span = Span::inactive();
let put = TestCachePut();
let mut ctx = TestCachePutCtx::new(put, key.clone(), &*CACHE_BACKEND, None, span);
let payload = b"HTTP/1.1 200 OK\r\n\
Date: Thu, 26 Apr 2018 05:42:05 GMT\r\n\
Content-Type: text/html; charset=utf-8\r\n\
Connection: keep-alive\r\n\
X-Frame-Options: SAMEORIGIN\r\n\
Cache-Control: public, max-age=1\r\n\
Server: origin-server\r\n\
Content-Length: 4\r\n\r\nrust";
// here we skip mocking a real http session for simplicity
let res = ctx.do_cache_put(payload).await.unwrap();
assert!(res.is_none()); // cacheable
ctx.parser.finish().unwrap();
ctx.finish().await.unwrap();
let span = Span::inactive();
let (meta, mut hit) = CACHE_BACKEND
.lookup(&key, &span.handle())
.await
.unwrap()
.unwrap();
assert_eq!(
meta.headers().get("date").unwrap(),
"Thu, 26 Apr 2018 05:42:05 GMT"
);
let data = hit.read_body().await.unwrap().unwrap();
assert_eq!(data, "rust");
}
#[tokio::test]
async fn test_cache_put_uncacheable() {
let key = CacheKey::new("", "a", "1");
let span = Span::inactive();
let put = TestCachePut();
let mut ctx = TestCachePutCtx::new(put, key.clone(), &*CACHE_BACKEND, None, span);
let payload = b"HTTP/1.1 200 OK\r\n\
Date: Thu, 26 Apr 2018 05:42:05 GMT\r\n\
Content-Type: text/html; charset=utf-8\r\n\
Connection: keep-alive\r\n\
X-Frame-Options: SAMEORIGIN\r\n\
Cache-Control: no-store\r\n\
Server: origin-server\r\n\
Content-Length: 4\r\n\r\nrust";
// here we skip mocking a real http session for simplicity
let no_cache = ctx.do_cache_put(payload).await.unwrap().unwrap();
assert_eq!(no_cache, NoCacheReason::OriginNotCache);
ctx.parser.finish().unwrap();
ctx.finish().await.unwrap();
}
}
// maybe this can simplify some logic in pingora::h1
mod parse_response {
use super::*;
use bytes::{Bytes, BytesMut};
use httparse::Status;
use pingora_error::{
Error,
ErrorType::{self, *},
Result,
};
use pingora_http::ResponseHeader;
pub const INVALID_CHUNK: ErrorType = ErrorType::new("InvalidChunk");
pub const INCOMPLETE_BODY: ErrorType = ErrorType::new("IncompleteHttpBody");
const MAX_HEADERS: usize = 256;
const INIT_HEADER_BUF_SIZE: usize = 4096;
const CHUNK_DELIMITER_SIZE: usize = 2; // \r\n
#[derive(Debug, Clone, Copy)]
enum ParseState {
Init,
PartialHeader,
PartialBodyContentLength(usize, usize),
PartialChunkedBody(usize),
PartialHttp10Body(usize),
Done(usize),
Invalid(httparse::Error),
}
impl ParseState {
fn is_done(&self) -> bool {
matches!(self, Self::Done(_))
}
fn read_header(&self) -> bool {
matches!(self, Self::Init | Self::PartialHeader)
}
fn read_body(&self) -> bool {
matches!(
self,
Self::PartialBodyContentLength(..)
| Self::PartialChunkedBody(_)
| Self::PartialHttp10Body(_)
)
}
}
pub(super) struct ResponseParse {
state: ParseState,
buf: BytesMut,
header_bytes: Bytes,
}
impl ResponseParse {
pub fn new() -> Self {
ResponseParse {
state: ParseState::Init,
buf: BytesMut::with_capacity(INIT_HEADER_BUF_SIZE),
header_bytes: Bytes::new(),
}
}
pub fn inject_data(&mut self, data: &[u8]) -> Result<Vec<HttpTask>> {
self.put_data(data);
let mut tasks = vec![];
while !self.state.is_done() {
if self.state.read_header() {
let header = self.parse_header()?;
let Some(header) = header else {
break;
};
tasks.push(HttpTask::Header(Box::new(header), self.state.is_done()));
} else if self.state.read_body() {
let body = self.parse_body()?;
let Some(body) = body else {
break;
};
tasks.push(HttpTask::Body(Some(body), self.state.is_done()));
} else {
break;
}
}
Ok(tasks)
}
fn put_data(&mut self, data: &[u8]) {
use ParseState::*;
if matches!(self.state, Done(_) | Invalid(_)) {
panic!("Wrong phase {:?}", self.state);
}
self.buf.extend_from_slice(data);
}
fn parse_header(&mut self) -> Result<Option<ResponseHeader>> {
let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut resp = httparse::Response::new(&mut headers);
let mut parser = httparse::ParserConfig::default();
parser.allow_spaces_after_header_name_in_responses(true);
parser.allow_obsolete_multiline_headers_in_responses(true);
let res = parser.parse_response(&mut resp, &self.buf);
let res = match res {
Ok(res) => res,
Err(e) => {
self.state = ParseState::Invalid(e);
return Error::e_because(
InvalidHTTPHeader,
format!("buf: {:?}", String::from_utf8_lossy(&self.buf)),
e,
);
}
};
let split_to = match res {
Status::Complete(s) => s,
Status::Partial => {
self.state = ParseState::PartialHeader;
return Ok(None);
}
};
// safe to unwrap, valid response always has code set.
let mut response =
ResponseHeader::build(resp.code.unwrap(), Some(resp.headers.len())).unwrap();
for header in resp.headers {
// TODO: consider hold a Bytes and all header values can be Bytes referencing the
// original buffer without reallocation
response.append_header(header.name.to_owned(), header.value.to_owned())?;
}
// TODO: see above, we can make header value `Bytes` referencing header_bytes
let header_bytes = self.buf.split_to(split_to).freeze();
self.header_bytes = header_bytes;
self.state = body_type(&response);
Ok(Some(response))
}
fn parse_body(&mut self) -> Result<Option<Bytes>> {
use ParseState::*;
if self.buf.is_empty() {
return Ok(None);
}
match self.state {
Init | PartialHeader | Invalid(_) => {
panic!("Wrong phase {:?}", self.state);
}
Done(_) => Ok(None),
PartialBodyContentLength(total, mut seen) => {
let end = if total < self.buf.len() + seen {
// TODO: warn! more data than expected
total - seen
} else {
self.buf.len()
};
seen += end;
if seen >= total {
self.state = Done(seen);
} else {
self.state = PartialBodyContentLength(total, seen);
}
Ok(Some(self.buf.split_to(end).freeze()))
}
PartialChunkedBody(seen) => {
let parsed = httparse::parse_chunk_size(&self.buf).map_err(|e| {
self.state = Done(seen);
Error::explain(INVALID_CHUNK, format!("Invalid chunked encoding: {e:?}"))
})?;
match parsed {
httparse::Status::Complete((header_len, body_len)) => {
// 4\r\nRust\r\n: header: "4\r\n", body: "Rust", "\r\n"
let total_chunk_size =
header_len + body_len as usize + CHUNK_DELIMITER_SIZE;
if self.buf.len() < total_chunk_size {
// wait for the full chunk to be read
// Note that we have to buffer the entire chunk in this design
Ok(None)
} else {
if body_len == 0 {
self.state = Done(seen);
} else {
self.state = PartialChunkedBody(seen + body_len as usize);
}
let mut chunk_bytes = self.buf.split_to(total_chunk_size);
let mut chunk_body = chunk_bytes.split_off(header_len);
chunk_body.truncate(body_len as usize);
// Note that the final zero-sized chunk returns an empty Bytes
// instead of None
Ok(Some(chunk_body.freeze()))
}
}
httparse::Status::Partial => {
// not even a full chunk, continue waiting for more data
Ok(None)
}
}
}
PartialHttp10Body(seen) => {
self.state = PartialHttp10Body(seen + self.buf.len());
Ok(Some(self.buf.split().freeze()))
}
}
}
pub fn finish(&mut self) -> Result<()> {
if let ParseState::PartialHttp10Body(seen) = self.state {
self.state = ParseState::Done(seen);
}
if !self.state.is_done() {
Error::e_explain(INCOMPLETE_BODY, format!("{:?}", self.state))
} else {
Ok(())
}
}
}
fn body_type(resp: &ResponseHeader) -> ParseState {
use http::StatusCode;
if matches!(
resp.status,
StatusCode::NO_CONTENT | StatusCode::NOT_MODIFIED
) {
// these status codes cannot have a body by definition
return ParseState::Done(0);
}
if let Some(encoding) = resp.headers.get(http::header::TRANSFER_ENCODING) {
// TODO: the `chunked` token is case-insensitive per RFC 9112; this comparison should be too
if encoding.as_bytes() == b"chunked" {
return ParseState::PartialChunkedBody(0);
}
}
if let Some(cl) = resp.headers.get(http::header::CONTENT_LENGTH) {
// ignore invalid header value
if let Some(cl) = std::str::from_utf8(cl.as_bytes())
.ok()
.and_then(|cl| cl.parse::<usize>().ok())
{
return if cl == 0 {
ParseState::Done(0)
} else {
ParseState::PartialBodyContentLength(cl, 0)
};
}
}
ParseState::PartialHttp10Body(0)
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_basic_response() {
let input = b"HTTP/1.1 200 OK\r\n\r\n";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input).unwrap();
assert_eq!(output.len(), 1);
let HttpTask::Header(header, eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(header.status, 200);
assert!(!eos);
let body = b"abc";
let output = parser.inject_data(body).unwrap();
assert_eq!(output.len(), 1);
let HttpTask::Body(data, _eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), &body[..]);
parser.finish().unwrap();
}
#[test]
fn test_partial_response_headers() {
let input = b"HTTP/1.1 200 OK\r\n";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input).unwrap();
// header is not complete
assert_eq!(output.len(), 0);
let output = parser
.inject_data("Server: pingora\r\n\r\n".as_bytes())
.unwrap();
assert_eq!(output.len(), 1);
let HttpTask::Header(header, eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(header.status, 200);
assert_eq!(header.headers.get("Server").unwrap(), "pingora");
assert!(!eos);
}
#[test]
fn test_invalid_headers() {
let input = b"HTP/1.1 200 OK\r\nServer: pingora\r\n\r\n";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input);
// invalid status line ("HTP/1.1") should fail parsing
assert!(output.is_err());
}
#[test]
fn test_body_content_length() {
let input = b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nabc";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input).unwrap();
assert_eq!(output.len(), 2);
let HttpTask::Header(header, _eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(header.status, 200);
let HttpTask::Body(data, eos) = &output[1] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), "abc");
assert!(!eos);
let output = parser.inject_data(b"def").unwrap();
assert_eq!(output.len(), 1);
let HttpTask::Body(data, eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), "def");
assert!(eos);
parser.finish().unwrap();
}
#[test]
fn test_body_chunked() {
let input = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nrust\r\n";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input).unwrap();
assert_eq!(output.len(), 2);
let HttpTask::Header(header, _eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(header.status, 200);
let HttpTask::Body(data, eos) = &output[1] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), "rust");
assert!(!eos);
let output = parser.inject_data(b"0\r\n\r\n").unwrap();
assert_eq!(output.len(), 1);
let HttpTask::Body(data, eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), "");
assert!(eos);
parser.finish().unwrap();
}
#[test]
fn test_body_content_length_early() {
let input = b"HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nabc";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input).unwrap();
assert_eq!(output.len(), 2);
let HttpTask::Header(header, _eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(header.status, 200);
let HttpTask::Body(data, eos) = &output[1] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), "abc");
assert!(!eos);
parser.finish().unwrap_err();
}
#[test]
fn test_body_content_length_more_data() {
let input = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nabc";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input).unwrap();
assert_eq!(output.len(), 2);
let HttpTask::Header(header, _eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(header.status, 200);
let HttpTask::Body(data, eos) = &output[1] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), "ab");
assert!(eos);
// extra data is dropped without error
parser.finish().unwrap();
}
#[test]
fn test_body_chunked_early() {
let input = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nrust\r\n";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input).unwrap();
assert_eq!(output.len(), 2);
let HttpTask::Header(header, _eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(header.status, 200);
let HttpTask::Body(data, eos) = &output[1] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), "rust");
assert!(!eos);
parser.finish().unwrap_err();
}
#[test]
fn test_body_chunked_partial_chunk() {
let input = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nru";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input).unwrap();
assert_eq!(output.len(), 1);
let HttpTask::Header(header, _eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(header.status, 200);
let output = parser.inject_data(b"st\r\n").unwrap();
assert_eq!(output.len(), 1);
let HttpTask::Body(data, eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), "rust");
assert!(!eos);
}
#[test]
fn test_body_chunked_partial_chunk_head() {
let input = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input).unwrap();
assert_eq!(output.len(), 1);
let HttpTask::Header(header, _eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(header.status, 200);
let output = parser.inject_data(b"\nrust\r\n").unwrap();
assert_eq!(output.len(), 1);
let HttpTask::Body(data, eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), "rust");
assert!(!eos);
}
#[test]
fn test_body_chunked_many_chunks() {
let input =
b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nrust\r\n1\r\ny\r\n";
let mut parser = ResponseParse::new();
let output = parser.inject_data(input).unwrap();
assert_eq!(output.len(), 3);
let HttpTask::Header(header, _eos) = &output[0] else {
panic!("{:?}", output);
};
assert_eq!(header.status, 200);
let HttpTask::Body(data, eos) = &output[1] else {
panic!("{:?}", output);
};
assert!(!eos);
assert_eq!(data.as_ref().unwrap(), "rust");
let HttpTask::Body(data, eos) = &output[2] else {
panic!("{:?}", output);
};
assert_eq!(data.as_ref().unwrap(), "y");
assert!(!eos);
}
}
}
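The chunk framing the parser's comments describe (a hex size line, the body, then a trailing CRLF) can be sketched standalone. This is a simplified illustration of the framing arithmetic only; the actual parser delegates to `httparse::parse_chunk_size` and buffers whole chunks, and `split_chunk` here is a hypothetical helper, not part of the codebase:

```rust
/// Simplified chunk framing: parse "<hex-size>\r\n<body>\r\n" from the
/// front of `buf`. Returns (body, total bytes consumed) when a full
/// chunk is buffered, or None when more data is needed.
fn split_chunk(buf: &[u8]) -> Option<(&[u8], usize)> {
    const CHUNK_DELIMITER_SIZE: usize = 2; // the trailing "\r\n"
    // find the end of the size line
    let pos = buf.windows(2).position(|w| w == b"\r\n")?;
    let size = usize::from_str_radix(std::str::from_utf8(&buf[..pos]).ok()?, 16).ok()?;
    let header_len = pos + 2;
    let total = header_len + size + CHUNK_DELIMITER_SIZE;
    if buf.len() < total {
        return None; // partial chunk: wait for more data
    }
    Some((&buf[header_len..header_len + size], total))
}
```

For `b"4\r\nRust\r\n"` this yields the body `Rust` and 9 bytes consumed; the final `0\r\n\r\n` chunk yields an empty body, matching the parser's empty-`Bytes` end-of-stream convention.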


@ -0,0 +1,122 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Cache backend storage abstraction
use super::{CacheKey, CacheMeta};
use crate::key::CompactCacheKey;
use crate::trace::SpanHandle;
use async_trait::async_trait;
use pingora_error::Result;
use std::any::Any;
/// Cache storage interface
#[async_trait]
pub trait Storage {
// TODO: shouldn't have to be static
/// Lookup the storage for the given [CacheKey]
async fn lookup(
&'static self,
key: &CacheKey,
trace: &SpanHandle,
) -> Result<Option<(CacheMeta, HitHandler)>>;
/// Write the given [CacheMeta] to the storage. Return [MissHandler] to write the body later.
async fn get_miss_handler(
&'static self,
key: &CacheKey,
meta: &CacheMeta,
trace: &SpanHandle,
) -> Result<MissHandler>;
/// Delete the cached asset for the given key
///
/// [CompactCacheKey] is used here because it is how eviction managers store the keys
async fn purge(&'static self, key: &CompactCacheKey, trace: &SpanHandle) -> Result<bool>;
/// Update cache header and metadata for the already stored asset.
async fn update_meta(
&'static self,
key: &CacheKey,
meta: &CacheMeta,
trace: &SpanHandle,
) -> Result<bool>;
/// Whether this storage backend supports reading partially written data
///
/// This is to indicate when cache should unlock readers
fn support_streaming_partial_write(&self) -> bool {
false
}
/// Helper function to cast the trait object to concrete types
fn as_any(&self) -> &(dyn Any + Send + Sync + 'static);
}
/// Cache hit handling trait
#[async_trait]
pub trait HandleHit {
/// Read cached body
///
/// Return `None` when no more body to read.
async fn read_body(&mut self) -> Result<Option<bytes::Bytes>>;
/// Finish the current cache hit
async fn finish(
self: Box<Self>, // because self is always used as a trait object
storage: &'static (dyn Storage + Sync),
key: &CacheKey,
trace: &SpanHandle,
) -> Result<()>;
/// Whether this storage allows seeking to a certain range of the body
fn can_seek(&self) -> bool {
false
}
/// Try to seek to a certain range of the body
///
/// `end: None` means to read to the end of the body.
fn seek(&mut self, _start: usize, _end: Option<usize>) -> Result<()> {
// to prevent implementing can_seek() without implementing seek()
todo!("seek() needs to be implemented")
}
// TODO: fn is_stream_hit()
/// Helper function to cast the trait object to concrete types
fn as_any(&self) -> &(dyn Any + Send + Sync);
}
/// Hit Handler
pub type HitHandler = Box<(dyn HandleHit + Sync + Send)>;
/// Cache miss handling trait
#[async_trait]
pub trait HandleMiss {
/// Write the given body to the storage
async fn write_body(&mut self, data: bytes::Bytes, eof: bool) -> Result<()>;
/// Finish the cache admission
///
/// When `self` is dropped without calling this function, the storage should consider this write
/// failed.
async fn finish(
self: Box<Self>, // because self is always used as a trait object
) -> Result<usize>;
}
/// Miss Handler
pub type MissHandler = Box<(dyn HandleMiss + Sync + Send)>;
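The `HandleMiss` contract above (buffer via `write_body`, commit only on `finish`, and treat a drop without `finish` as a failed write) can be illustrated with a minimal in-memory analog. This is a hypothetical simplified sketch, not the trait itself: it is synchronous, uses a plain `HashMap` as the "storage", and omits error handling:

```rust
use std::collections::HashMap;

// A minimal stand-in for a miss handler: body bytes are buffered during
// write_body() and only committed to the cache on finish(), which
// reports the admitted size. Dropping the handler without finish()
// discards the buffered write.
struct MemMissHandler {
    key: String,
    body: Vec<u8>,
}

impl MemMissHandler {
    fn write_body(&mut self, data: &[u8], _eof: bool) {
        self.body.extend_from_slice(data);
    }
    // consumes self so the handler cannot be reused after commit,
    // mirroring finish(self: Box<Self>) in the real trait
    fn finish(self, cache: &mut HashMap<String, Vec<u8>>) -> usize {
        let size = self.body.len();
        cache.insert(self.key, self.body);
        size
    }
}
```

Consuming `self` in `finish` is the same design choice the trait makes: the type system guarantees no writes can follow the commit.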


@ -0,0 +1,98 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Distributed tracing helpers
use rustracing_jaeger::span::SpanContextState;
use std::time::SystemTime;
use crate::{CacheMeta, CachePhase, HitStatus};
pub use rustracing::tag::Tag;
pub type Span = rustracing::span::Span<SpanContextState>;
pub type SpanHandle = rustracing::span::SpanHandle<SpanContextState>;
#[derive(Debug)]
pub(crate) struct CacheTraceCTX {
// parent span
pub cache_span: Span,
// only spans that persist across multiple calls need to be stored here
pub miss_span: Span,
pub hit_span: Span,
}
impl CacheTraceCTX {
pub fn new() -> Self {
CacheTraceCTX {
cache_span: Span::inactive(),
miss_span: Span::inactive(),
hit_span: Span::inactive(),
}
}
pub fn enable(&mut self, cache_span: Span) {
self.cache_span = cache_span;
}
#[inline]
pub fn child(&self, name: &'static str) -> Span {
self.cache_span.child(name, |o| o.start())
}
pub fn start_miss_span(&mut self) {
self.miss_span = self.child("miss");
}
pub fn get_miss_span(&self) -> SpanHandle {
self.miss_span.handle()
}
pub fn finish_miss_span(&mut self) {
self.miss_span.set_finish_time(SystemTime::now);
}
pub fn start_hit_span(&mut self, phase: CachePhase, hit_status: HitStatus) {
self.hit_span = self.child("hit");
self.hit_span.set_tag(|| Tag::new("phase", phase.as_str()));
self.hit_span
.set_tag(|| Tag::new("status", hit_status.as_str()));
}
pub fn finish_hit_span(&mut self) {
self.hit_span.set_finish_time(SystemTime::now);
}
pub fn log_meta(&mut self, meta: &CacheMeta) {
fn ts2epoch(ts: SystemTime) -> f64 {
ts.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default() // should never overflow but be safe here
.as_secs_f64()
}
let internal = &meta.0.internal;
self.hit_span.set_tags(|| {
[
Tag::new("created", ts2epoch(internal.created)),
Tag::new("fresh_until", ts2epoch(internal.fresh_until)),
Tag::new("updated", ts2epoch(internal.updated)),
Tag::new("stale_if_error_sec", internal.stale_if_error_sec as i64),
Tag::new(
"stale_while_revalidate_sec",
internal.stale_while_revalidate_sec as i64,
),
Tag::new("variance", internal.variance.is_some()),
]
});
}
}


@ -0,0 +1,120 @@
use std::{borrow::Cow, collections::BTreeMap};
use blake2::Digest;
use crate::key::{Blake2b128, HashBinary};
/// A builder for variance keys, used for distinguishing multiple cached assets
/// at the same URL. This is intended to be easily passed to helper functions,
/// which can each populate a portion of the variance.
pub struct VarianceBuilder<'a> {
values: BTreeMap<Cow<'a, str>, Cow<'a, [u8]>>,
}
impl<'a> VarianceBuilder<'a> {
/// Create an empty variance key. Has no variance by default - add some variance using
/// [`Self::add_value`].
pub fn new() -> Self {
VarianceBuilder {
values: BTreeMap::new(),
}
}
/// Add a byte string to the variance key. Not sensitive to insertion order.
/// `value` is intended to take either `&str` or `&[u8]`.
pub fn add_value(&mut self, name: &'a str, value: &'a (impl AsRef<[u8]> + ?Sized)) {
self.values
.insert(name.into(), Cow::Borrowed(value.as_ref()));
}
/// Move a byte string to the variance key. Not sensitive to insertion order. Useful when
/// writing helper functions which generate a value then add said value to the VarianceBuilder.
/// Without this, the helper function would have to move the value to the calling function
/// to extend its lifetime to at least match the VarianceBuilder.
pub fn add_owned_value(&mut self, name: &'a str, value: Vec<u8>) {
self.values.insert(name.into(), Cow::Owned(value));
}
/// Check whether this variance key actually has variance, or just refers to the root asset
pub fn has_variance(&self) -> bool {
!self.values.is_empty()
}
/// Hash this variance key. Returns [`None`] if [`Self::has_variance`] is false.
pub fn finalize(self) -> Option<HashBinary> {
const SALT: &[u8; 1] = &[0u8; 1];
if self.has_variance() {
let mut hash = Blake2b128::new();
for (name, value) in self.values.iter() {
hash.update(name.as_bytes());
hash.update(SALT);
hash.update(value);
hash.update(SALT);
}
Some(hash.finalize().into())
} else {
None
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_basic() {
let key_empty = VarianceBuilder::new().finalize();
assert_eq!(None, key_empty);
let mut key_value = VarianceBuilder::new();
key_value.add_value("a", "a");
let key_value = key_value.finalize();
let mut key_owned_value = VarianceBuilder::new();
key_owned_value.add_owned_value("a", "a".as_bytes().to_vec());
let key_owned_value = key_owned_value.finalize();
assert_ne!(key_empty, key_value);
assert_ne!(key_empty, key_owned_value);
assert_eq!(key_value, key_owned_value);
}
#[test]
fn test_value_ordering() {
let mut key_abc = VarianceBuilder::new();
key_abc.add_value("a", "a");
key_abc.add_value("b", "b");
key_abc.add_value("c", "c");
let key_abc = key_abc.finalize().unwrap();
let mut key_bac = VarianceBuilder::new();
key_bac.add_value("b", "b");
key_bac.add_value("a", "a");
key_bac.add_value("c", "c");
let key_bac = key_bac.finalize().unwrap();
let mut key_cba = VarianceBuilder::new();
key_cba.add_value("c", "c");
key_cba.add_value("b", "b");
key_cba.add_value("a", "a");
let key_cba = key_cba.finalize().unwrap();
assert_eq!(key_abc, key_bac);
assert_eq!(key_abc, key_cba);
}
#[test]
fn test_value_overriding() {
let mut key_a = VarianceBuilder::new();
key_a.add_value("a", "a");
let key_a = key_a.finalize().unwrap();
let mut key_b = VarianceBuilder::new();
key_b.add_value("a", "b");
key_b.add_value("a", "a");
let key_b = key_b.finalize().unwrap();
assert_eq!(key_a, key_b);
}
}
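The insertion-order independence tested above comes from `BTreeMap`'s sorted iteration plus the zero-byte salt separating each field. A simplified standalone sketch of the same idea, using the standard library's `DefaultHasher` purely as a stand-in for `Blake2b128` (the real builder also borrows values via `Cow` rather than taking an owned map):

```rust
use std::collections::BTreeMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Hash (name, value) pairs in sorted key order with a zero-byte salt
// between fields, mirroring the structure of VarianceBuilder::finalize().
// Returns None when there is no variance, i.e. the root asset.
fn variance_hash(values: &BTreeMap<&str, &[u8]>) -> Option<u64> {
    if values.is_empty() {
        return None;
    }
    let mut h = DefaultHasher::new();
    for (name, value) in values {
        name.as_bytes().hash(&mut h);
        [0u8].hash(&mut h); // salt separates name from value
        value.hash(&mut h);
        [0u8].hash(&mut h);
    }
    Some(h.finish())
}
```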

81
pingora-core/Cargo.toml Normal file

@ -0,0 +1,81 @@
[package]
name = "pingora-core"
version = "0.1.0"
authors = ["Yuchen Wu <yuchen@cloudflare.com>"]
license = "Apache-2.0"
edition = "2021"
repository = "https://github.com/cloudflare/pingora"
categories = ["asynchronous", "network-programming"]
keywords = ["async", "http", "network", "pingora"]
exclude = ["tests/*"]
description = """
Pingora's APIs and traits for the core network protocols.
"""
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "pingora_core"
path = "src/lib.rs"
[dependencies]
pingora-runtime = { version = "0.1.0", path = "../pingora-runtime" }
pingora-openssl = { version = "0.1.0", path = "../pingora-openssl", optional = true }
pingora-boringssl = { version = "0.1.0", path = "../pingora-boringssl", optional = true }
pingora-pool = { version = "0.1.0", path = "../pingora-pool" }
pingora-error = { version = "0.1.0", path = "../pingora-error" }
pingora-timeout = { version = "0.1.0", path = "../pingora-timeout" }
pingora-http = { version = "0.1.0", path = "../pingora-http" }
tokio = { workspace = true, features = ["rt-multi-thread", "signal"] }
futures = "0.3"
async-trait = { workspace = true }
httparse = { workspace = true }
bytes = { workspace = true }
http = { workspace = true }
log = { workspace = true }
h2 = { workspace = true }
lru = { workspace = true }
nix = "0.24"
structopt = "0.3"
once_cell = { workspace = true }
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.8"
libc = "0.2.70"
chrono = { version = "0.4", features = ["alloc"], default-features = false }
thread_local = "1.0"
prometheus = "0.13"
daemonize = "0.5.0"
sentry = { version = "0.26", features = [
"backtrace",
"contexts",
"panic",
"reqwest",
"rustls",
], default-features = false }
regex = "1"
percent-encoding = "2.1"
parking_lot = "0.12"
socket2 = { version = "0", features = ["all"] }
flate2 = { version = "1", features = ["zlib-ng"], default-features = false }
sfv = "0"
rand = "0.8"
ahash = { workspace = true }
unicase = "2"
brotli = "3"
openssl-probe = "0.1"
tokio-test = "0.4"
zstd = "0"
[dev-dependencies]
matches = "0.1"
env_logger = "0.9"
reqwest = { version = "0.11", features = ["rustls"], default-features = false }
hyperlocal = "0.8"
hyper = "0.14"
jemallocator = "0.5"
[features]
default = ["openssl"]
openssl = ["pingora-openssl"]
boringssl = ["pingora-boringssl"]
patched_http1 = []

202
pingora-core/LICENSE Normal file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@ -0,0 +1,210 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A simple HTTP application trait that maps a request to a response
use async_trait::async_trait;
use http::Response;
use log::{debug, error, trace};
use pingora_http::ResponseHeader;
use std::sync::Arc;
use crate::apps::HttpServerApp;
use crate::modules::http::{HttpModules, ModuleBuilder};
use crate::protocols::http::HttpTask;
use crate::protocols::http::ServerSession;
use crate::protocols::Stream;
use crate::server::ShutdownWatch;
/// This trait defines how to map a request to a response
#[cfg_attr(not(doc_async_trait), async_trait)]
pub trait ServeHttp {
/// Define the mapping from a request to a response.
/// Note that the request header is already read, but the implementation needs to read the
/// request body if any.
///
/// # Limitation
/// In this API, the entire response has to be generated before the end of this call.
/// So it is not suitable for streaming response or interactive communications.
/// Users need to implement their own [`super::HttpServerApp`] for those use cases.
async fn response(&self, http_session: &mut ServerSession) -> Response<Vec<u8>>;
}
// TODO: remove this in favor of HttpServer?
#[cfg_attr(not(doc_async_trait), async_trait)]
impl<SV> HttpServerApp for SV
where
SV: ServeHttp + Send + Sync,
{
async fn process_new_http(
self: &Arc<Self>,
mut http: ServerSession,
shutdown: &ShutdownWatch,
) -> Option<Stream> {
match http.read_request().await {
Ok(res) => match res {
false => {
debug!("Failed to read request header");
return None;
}
true => {
                    debug!("Successfully got a new request");
}
},
Err(e) => {
error!("HTTP server fails to read from downstream: {e}");
return None;
}
}
trace!("{:?}", http.req_header());
if *shutdown.borrow() {
http.set_keepalive(None);
} else {
http.set_keepalive(Some(60));
}
let new_response = self.response(&mut http).await;
let (parts, body) = new_response.into_parts();
let resp_header: ResponseHeader = parts.into();
match http.write_response_header(Box::new(resp_header)).await {
Ok(()) => {
debug!("HTTP response header done.");
}
Err(e) => {
error!(
"HTTP server fails to write to downstream: {e}, {}",
http.request_summary()
);
}
}
if !body.is_empty() {
// TODO: check if chunked encoding is needed
match http.write_response_body(body.into()).await {
Ok(_) => debug!("HTTP response written."),
Err(e) => error!(
"HTTP server fails to write to downstream: {e}, {}",
http.request_summary()
),
}
}
match http.finish().await {
Ok(c) => c,
Err(e) => {
error!("HTTP server fails to finish the request: {e}");
None
}
}
}
}
/// A helper struct for HTTP server with http modules embedded
pub struct HttpServer<SV> {
app: SV,
modules: HttpModules,
}
impl<SV> HttpServer<SV> {
/// Create a new [HttpServer] with the given app which implements [ServeHttp]
pub fn new_app(app: SV) -> Self {
HttpServer {
app,
modules: HttpModules::new(),
}
}
/// Add [ModuleBuilder] to this [HttpServer]
pub fn add_module(&mut self, module: ModuleBuilder) {
self.modules.add_module(module)
}
}
#[cfg_attr(not(doc_async_trait), async_trait)]
impl<SV> HttpServerApp for HttpServer<SV>
where
SV: ServeHttp + Send + Sync,
{
async fn process_new_http(
self: &Arc<Self>,
mut http: ServerSession,
shutdown: &ShutdownWatch,
) -> Option<Stream> {
match http.read_request().await {
Ok(res) => match res {
false => {
debug!("Failed to read request header");
return None;
}
true => {
                    debug!("Successfully got a new request");
}
},
Err(e) => {
error!("HTTP server fails to read from downstream: {e}");
return None;
}
}
trace!("{:?}", http.req_header());
if *shutdown.borrow() {
http.set_keepalive(None);
} else {
http.set_keepalive(Some(60));
}
let mut module_ctx = self.modules.build_ctx();
let req = http.req_header_mut();
module_ctx.request_header_filter(req).ok()?;
let new_response = self.app.response(&mut http).await;
let (parts, body) = new_response.into_parts();
let resp_header: ResponseHeader = parts.into();
let mut task = HttpTask::Header(Box::new(resp_header), body.is_empty());
module_ctx.response_filter(&mut task).ok()?;
trace!("{task:?}");
match http.response_duplex_vec(vec![task]).await {
Ok(_) => {
debug!("HTTP response header done.");
}
Err(e) => {
error!(
"HTTP server fails to write to downstream: {e}, {}",
http.request_summary()
);
}
}
let mut task = if !body.is_empty() {
HttpTask::Body(Some(body.into()), true)
} else {
HttpTask::Body(None, true)
};
trace!("{task:?}");
module_ctx.response_filter(&mut task).ok()?;
// TODO: check if chunked encoding is needed
match http.response_duplex_vec(vec![task]).await {
Ok(_) => debug!("HTTP response written."),
Err(e) => error!(
"HTTP server fails to write to downstream: {e}, {}",
http.request_summary()
),
}
match http.finish().await {
Ok(c) => c,
Err(e) => {
error!("HTTP server fails to finish the request: {e}");
None
}
}
}
}
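The `ServeHttp` trait above reduces an HTTP application to a single request-to-response mapping, with the server glue (reading the request, keepalive, writing the response) handled by the blanket `HttpServerApp` implementation. Ignoring the async machinery and the real pingora session types, the shape of that abstraction can be sketched in plain std Rust (all names here are illustrative, not part of the pingora API):

```rust
// Minimal model of the ServeHttp idea: an app is just a request -> response map.
// The real trait is async and operates on pingora's ServerSession; this sketch
// uses a path string and a (status, body) pair to show only the control flow.
trait Serve {
    fn response(&self, request_path: &str) -> (u16, Vec<u8>);
}

struct HelloApp;

impl Serve for HelloApp {
    fn response(&self, request_path: &str) -> (u16, Vec<u8>) {
        match request_path {
            "/" => (200, b"hello".to_vec()),
            _ => (404, b"not found".to_vec()),
        }
    }
}

fn main() {
    let app = HelloApp;
    // the server framework, not the app, decides how these values reach the wire
    let (status, body) = app.response("/");
    assert_eq!(status, 200);
    assert_eq!(body, b"hello");
    let (status, _) = app.response("/missing");
    assert_eq!(status, 404);
}
```

As the doc comment notes, this style requires the whole body up front; streaming responses need a full `HttpServerApp` implementation instead.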


@@ -0,0 +1,135 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! The abstraction and implementation interface for service application logic
pub mod http_app;
pub mod prometheus_http_app;
use crate::server::ShutdownWatch;
use async_trait::async_trait;
use log::{debug, error};
use std::sync::Arc;
use crate::protocols::http::v2::server;
use crate::protocols::http::ServerSession;
use crate::protocols::Stream;
use crate::protocols::ALPN;
#[cfg_attr(not(doc_async_trait), async_trait)]
/// This trait defines the interface of a transport layer (TCP or TLS) application.
pub trait ServerApp {
    /// Whenever a new connection is established, this function will be called with the established
    /// [`Stream`] object.
///
/// The application can do whatever it wants with the `session`.
///
    /// After processing the `session`, if the `session`'s connection is reusable, this function
/// can return it to the service by returning `Some(session)`. The returned `session` will be
/// fed to another [`Self::process_new()`] for another round of processing.
/// If not reusable, `None` should be returned.
///
    /// The `shutdown` argument will change from `false` to `true` when the server receives a
    /// signal to shut down. This argument allows the application to react accordingly.
async fn process_new(
self: &Arc<Self>,
mut session: Stream,
        // TODO: make this ShutdownWatch so that all tasks can await on this event
shutdown: &ShutdownWatch,
) -> Option<Stream>;
/// This callback will be called once after the service stops listening to its endpoints.
fn cleanup(&self) {}
}
/// This trait defines the interface of an HTTP application.
#[cfg_attr(not(doc_async_trait), async_trait)]
pub trait HttpServerApp {
/// Similar to the [`ServerApp`], this function is called whenever a new HTTP session is established.
///
/// After successful processing, [`ServerSession::finish()`] can be called to return an optionally reusable
    /// connection back to the service. The caller needs to make sure that the connection is in a reusable state,
    /// i.e., there were no errors and no incomplete reads or writes of headers or bodies. Otherwise, `None` should be returned.
async fn process_new_http(
self: &Arc<Self>,
mut session: ServerSession,
        // TODO: make this ShutdownWatch so that all tasks can await on this event
shutdown: &ShutdownWatch,
) -> Option<Stream>;
/// Provide options on how HTTP/2 connection should be established. This function will be called
/// every time a new HTTP/2 **connection** needs to be established.
///
/// A `None` means to use the built-in default options. See [`server::H2Options`] for more details.
fn h2_options(&self) -> Option<server::H2Options> {
None
}
fn http_cleanup(&self) {}
}
#[cfg_attr(not(doc_async_trait), async_trait)]
impl<T> ServerApp for T
where
T: HttpServerApp + Send + Sync + 'static,
{
async fn process_new(
self: &Arc<Self>,
stream: Stream,
shutdown: &ShutdownWatch,
) -> Option<Stream> {
match stream.selected_alpn_proto() {
Some(ALPN::H2) => {
let h2_options = self.h2_options();
let h2_conn = server::handshake(stream, h2_options).await;
let mut h2_conn = match h2_conn {
Err(e) => {
error!("H2 handshake error {e}");
return None;
}
Ok(c) => c,
};
loop {
// this loop ends when the client decides to close the h2 conn
// TODO: add a timeout?
let h2_stream = server::HttpSession::from_h2_conn(&mut h2_conn).await;
let h2_stream = match h2_stream {
Err(e) => {
                            // It is common for the client to just disconnect TCP without
                            // properly closing H2, so we don't log those errors here
debug!("H2 error when accepting new stream {e}");
return None;
}
Ok(s) => s?, // None means the connection is ready to be closed
};
let app = self.clone();
let shutdown = shutdown.clone();
pingora_runtime::current_handle().spawn(async move {
app.process_new_http(ServerSession::new_http2(h2_stream), &shutdown)
.await;
});
}
}
_ => {
// No ALPN or ALPN::H1 or something else, just try Http1
self.process_new_http(ServerSession::new_http1(stream), shutdown)
.await
}
}
}
fn cleanup(&self) {
self.http_cleanup()
}
}
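The `ServerApp::process_new()` contract above is a reuse loop: the service hands the app a stream, and the app either consumes it (`None`) or hands it back (`Some(stream)`) for another round. A self-contained model of that driver loop, with a counter standing in for keepalive (the `Conn` type and `requests_left` field are invented for illustration):

```rust
// Model of the ServerApp reuse contract: process a connection, and if the app
// hands the stream back (Some), feed it into another round of processing.
struct Conn {
    requests_left: u32, // simulates how long the peer keeps the connection alive
}

trait App {
    fn process(&self, conn: Conn) -> Option<Conn>;
}

struct CountingApp;

impl App for CountingApp {
    fn process(&self, mut conn: Conn) -> Option<Conn> {
        if conn.requests_left == 0 {
            return None; // not reusable: the connection is closed
        }
        conn.requests_left -= 1;
        Some(conn) // reusable: return it to the service for another round
    }
}

fn serve(app: &impl App, mut conn: Conn) -> u32 {
    let mut rounds = 0;
    // the service keeps feeding the returned stream back to the app
    // until the app declares the connection no longer reusable
    while let Some(c) = app.process(conn) {
        rounds += 1;
        conn = c;
    }
    rounds
}

fn main() {
    assert_eq!(serve(&CountingApp, Conn { requests_left: 3 }), 3);
}
```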


@@ -0,0 +1,60 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! An HTTP application that reports Prometheus metrics.
use async_trait::async_trait;
use http::{self, Response};
use prometheus::{Encoder, TextEncoder};
use super::http_app::HttpServer;
use crate::apps::http_app::ServeHttp;
use crate::modules::http::compression::ResponseCompressionBuilder;
use crate::protocols::http::ServerSession;
/// An HTTP application that reports Prometheus metrics.
///
/// This application will report all the [static metrics](https://docs.rs/prometheus/latest/prometheus/index.html#static-metrics)
/// collected via the [Prometheus](https://docs.rs/prometheus/) crate.
pub struct PrometheusHttpApp;
#[cfg_attr(not(doc_async_trait), async_trait)]
impl ServeHttp for PrometheusHttpApp {
async fn response(&self, _http_session: &mut ServerSession) -> Response<Vec<u8>> {
let encoder = TextEncoder::new();
let metric_families = prometheus::gather();
let mut buffer = vec![];
encoder.encode(&metric_families, &mut buffer).unwrap();
Response::builder()
.status(200)
.header(http::header::CONTENT_TYPE, encoder.format_type())
.header(http::header::CONTENT_LENGTH, buffer.len())
.body(buffer)
.unwrap()
}
}
/// The [HttpServer] for [PrometheusHttpApp]
///
/// This type provides the functionality of [PrometheusHttpApp] with compression enabled
pub type PrometheusServer = HttpServer<PrometheusHttpApp>;
impl PrometheusServer {
pub fn new() -> Self {
let mut server = Self::new_app(PrometheusHttpApp);
// enable gzip level 7 compression
server.add_module(ResponseCompressionBuilder::enable(7));
server
}
}


@@ -0,0 +1,221 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Connecting to HTTP servers
use crate::connectors::ConnectorOptions;
use crate::protocols::http::client::HttpSession;
use crate::upstreams::peer::Peer;
use pingora_error::Result;
use std::time::Duration;
pub mod v1;
pub mod v2;
pub struct Connector {
h1: v1::Connector,
h2: v2::Connector,
}
impl Connector {
pub fn new(options: Option<ConnectorOptions>) -> Self {
Connector {
h1: v1::Connector::new(options.clone()),
h2: v2::Connector::new(options),
}
}
/// Get an [HttpSession] to the given server.
///
/// The second return value indicates whether the session is connected via a reused stream.
pub async fn get_http_session<P: Peer + Send + Sync + 'static>(
&self,
peer: &P,
) -> Result<(HttpSession, bool)> {
        // NOTE: maybe TODO: we do not yet enforce that only TLS traffic can use h2, which is the
        // de facto requirement for h2, because non-TLS traffic lacks the negotiation mechanism.
        // We assume no peer option == no ALPN == h1 only.
let h1_only = peer
.get_peer_options()
.map_or(true, |o| o.alpn.get_max_http_version() == 1);
if h1_only {
let (h1, reused) = self.h1.get_http_session(peer).await?;
Ok((HttpSession::H1(h1), reused))
} else {
// the peer allows h2, we first check the h2 reuse pool
let reused_h2 = self.h2.reused_http_session(peer).await?;
if let Some(h2) = reused_h2 {
return Ok((HttpSession::H2(h2), true));
}
let h2_only = peer
.get_peer_options()
.map_or(false, |o| o.alpn.get_min_http_version() == 2)
&& !self.h2.h1_is_preferred(peer);
if !h2_only {
// We next check the reuse pool for h1 before creating a new h2 connection.
// This is because the server may not support h2 at all, connections to
// the server could all be h1.
if let Some(h1) = self.h1.reused_http_session(peer).await {
return Ok((HttpSession::H1(h1), true));
}
}
let session = self.h2.new_http_session(peer).await?;
Ok((session, false))
}
}
pub async fn release_http_session<P: Peer + Send + Sync + 'static>(
&self,
session: HttpSession,
peer: &P,
idle_timeout: Option<Duration>,
) {
match session {
HttpSession::H1(h1) => self.h1.release_http_session(h1, peer, idle_timeout).await,
HttpSession::H2(h2) => self.h2.release_http_session(h2, peer, idle_timeout),
}
}
/// Tell the connector to always send h1 for ALPN for the given peer in the future.
pub fn prefer_h1(&self, peer: &impl Peer) {
self.h2.prefer_h1(peer);
}
}
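The branching in `get_http_session()` above boils down to a pure decision over the peer's ALPN version bounds: no options or max version 1 means h1 only; otherwise the h2 reuse pool is tried first, and the h1 reuse pool is also consulted unless the peer is h2-only and h1 is not preferred. A sketch of that decision as a standalone function (the `Plan` enum and tuple encoding of `(min, max)` HTTP versions are invented for illustration):

```rust
#[derive(Debug, PartialEq)]
enum Plan {
    H1Only,         // go straight to the h1 path
    H2ThenH1Reuse,  // h2 reuse pool, then h1 reuse pool, then a new h2 connection
    H2NoH1Fallback, // h2 reuse pool, then a new h2 connection only
}

// opts = Some((min_http_version, max_http_version)); None = no peer options set.
fn plan(opts: Option<(u8, u8)>, h1_preferred: bool) -> Plan {
    // no peer option == no ALPN == h1 only, mirroring the connector's default
    let h1_only = opts.map_or(true, |(_, max)| max == 1);
    if h1_only {
        return Plan::H1Only;
    }
    let h2_only = opts.map_or(false, |(min, _)| min == 2) && !h1_preferred;
    if h2_only {
        Plan::H2NoH1Fallback
    } else {
        // the server may not support h2 at all, so existing h1 sessions
        // are checked before paying for a new h2 connection
        Plan::H2ThenH1Reuse
    }
}

fn main() {
    assert_eq!(plan(None, false), Plan::H1Only);
    assert_eq!(plan(Some((1, 1)), false), Plan::H1Only);
    assert_eq!(plan(Some((1, 2)), false), Plan::H2ThenH1Reuse);
    assert_eq!(plan(Some((2, 2)), false), Plan::H2NoH1Fallback);
    assert_eq!(plan(Some((2, 2)), true), Plan::H2ThenH1Reuse);
}
```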
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::http::v1::client::HttpSession as Http1Session;
use crate::upstreams::peer::HttpPeer;
use pingora_http::RequestHeader;
async fn get_http(http: &mut Http1Session, expected_status: u16) {
let mut req = Box::new(RequestHeader::build("GET", b"/", None).unwrap());
req.append_header("Host", "one.one.one.one").unwrap();
http.write_request_header(req).await.unwrap();
http.read_response().await.unwrap();
http.respect_keepalive();
assert_eq!(http.get_status().unwrap(), expected_status);
while http.read_body_bytes().await.unwrap().is_some() {}
}
#[tokio::test]
async fn test_connect_h2() {
let connector = Connector::new(None);
let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into());
peer.options.set_http_version(2, 2);
let (h2, reused) = connector.get_http_session(&peer).await.unwrap();
assert!(!reused);
match &h2 {
HttpSession::H1(_) => panic!("expect h2"),
HttpSession::H2(h2_stream) => assert!(!h2_stream.ping_timedout()),
}
connector.release_http_session(h2, &peer, None).await;
let (h2, reused) = connector.get_http_session(&peer).await.unwrap();
// reused this time
assert!(reused);
match &h2 {
HttpSession::H1(_) => panic!("expect h2"),
HttpSession::H2(h2_stream) => assert!(!h2_stream.ping_timedout()),
}
}
#[tokio::test]
async fn test_connect_h1() {
let connector = Connector::new(None);
let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into());
peer.options.set_http_version(1, 1);
let (mut h1, reused) = connector.get_http_session(&peer).await.unwrap();
assert!(!reused);
match &mut h1 {
HttpSession::H1(http) => {
get_http(http, 200).await;
}
HttpSession::H2(_) => panic!("expect h1"),
}
connector.release_http_session(h1, &peer, None).await;
let (mut h1, reused) = connector.get_http_session(&peer).await.unwrap();
// reused this time
assert!(reused);
match &mut h1 {
HttpSession::H1(_) => {}
HttpSession::H2(_) => panic!("expect h1"),
}
}
#[tokio::test]
async fn test_connect_h2_fallback_h1_reuse() {
        // this test verifies that if the server doesn't support h2, the Connector will reuse the
        // h1 session instead.
let connector = Connector::new(None);
let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into());
        // As it is hard to find a server that supports only h1, we use the following hack to trick
        // the connector into thinking the server supports only h1: we force ALPN to use h1 and then
        // return the connection to the connector. Then we use a peer that allows h2.
peer.options.set_http_version(1, 1);
let (mut h1, reused) = connector.get_http_session(&peer).await.unwrap();
assert!(!reused);
match &mut h1 {
HttpSession::H1(http) => {
get_http(http, 200).await;
}
HttpSession::H2(_) => panic!("expect h1"),
}
connector.release_http_session(h1, &peer, None).await;
let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into());
peer.options.set_http_version(2, 1);
let (mut h1, reused) = connector.get_http_session(&peer).await.unwrap();
// reused this time
assert!(reused);
match &mut h1 {
HttpSession::H1(_) => {}
HttpSession::H2(_) => panic!("expect h1"),
}
}
#[tokio::test]
async fn test_connect_prefer_h1() {
let connector = Connector::new(None);
let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into());
peer.options.set_http_version(2, 1);
connector.prefer_h1(&peer);
let (mut h1, reused) = connector.get_http_session(&peer).await.unwrap();
assert!(!reused);
match &mut h1 {
HttpSession::H1(http) => {
get_http(http, 200).await;
}
HttpSession::H2(_) => panic!("expect h1"),
}
connector.release_http_session(h1, &peer, None).await;
peer.options.set_http_version(2, 2);
let (mut h1, reused) = connector.get_http_session(&peer).await.unwrap();
// reused this time
assert!(reused);
match &mut h1 {
HttpSession::H1(_) => {}
HttpSession::H2(_) => panic!("expect h1"),
}
}
}


@@ -0,0 +1,119 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::connectors::{ConnectorOptions, TransportConnector};
use crate::protocols::http::v1::client::HttpSession;
use crate::upstreams::peer::Peer;
use pingora_error::Result;
use std::time::Duration;
pub struct Connector {
transport: TransportConnector,
}
impl Connector {
pub fn new(options: Option<ConnectorOptions>) -> Self {
Connector {
transport: TransportConnector::new(options),
}
}
pub async fn get_http_session<P: Peer + Send + Sync + 'static>(
&self,
peer: &P,
) -> Result<(HttpSession, bool)> {
let (stream, reused) = self.transport.get_stream(peer).await?;
let http = HttpSession::new(stream);
Ok((http, reused))
}
pub async fn reused_http_session<P: Peer + Send + Sync + 'static>(
&self,
peer: &P,
) -> Option<HttpSession> {
self.transport
.reused_stream(peer)
.await
.map(HttpSession::new)
}
pub async fn release_http_session<P: Peer + Send + Sync + 'static>(
&self,
session: HttpSession,
peer: &P,
idle_timeout: Option<Duration>,
) {
if let Some(stream) = session.reuse().await {
self.transport
.release_stream(stream, peer.reuse_hash(), idle_timeout);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::upstreams::peer::HttpPeer;
use pingora_http::RequestHeader;
async fn get_http(http: &mut HttpSession, expected_status: u16) {
let mut req = Box::new(RequestHeader::build("GET", b"/", None).unwrap());
req.append_header("Host", "one.one.one.one").unwrap();
http.write_request_header(req).await.unwrap();
http.read_response().await.unwrap();
http.respect_keepalive();
assert_eq!(http.get_status().unwrap(), expected_status);
while http.read_body_bytes().await.unwrap().is_some() {}
}
#[tokio::test]
async fn test_connect() {
let connector = Connector::new(None);
let peer = HttpPeer::new(("1.1.1.1", 80), false, "".into());
// make a new connection to 1.1.1.1
let (http, reused) = connector.get_http_session(&peer).await.unwrap();
assert!(!reused);
        // this session was never used, so it cannot be reused
connector.release_http_session(http, &peer, None).await;
let (mut http, reused) = connector.get_http_session(&peer).await.unwrap();
assert!(!reused);
get_http(&mut http, 301).await;
connector.release_http_session(http, &peer, None).await;
let (_, reused) = connector.get_http_session(&peer).await.unwrap();
assert!(reused);
}
#[tokio::test]
async fn test_connect_tls() {
let connector = Connector::new(None);
let peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into());
// make a new connection to https://1.1.1.1
let (http, reused) = connector.get_http_session(&peer).await.unwrap();
assert!(!reused);
        // this session was never used, so it cannot be reused
connector.release_http_session(http, &peer, None).await;
let (mut http, reused) = connector.get_http_session(&peer).await.unwrap();
assert!(!reused);
get_http(&mut http, 200).await;
connector.release_http_session(http, &peer, None).await;
let (_, reused) = connector.get_http_session(&peer).await.unwrap();
assert!(reused);
}
}


@@ -0,0 +1,531 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::HttpSession;
use crate::connectors::{ConnectorOptions, TransportConnector};
use crate::protocols::http::v1::client::HttpSession as Http1Session;
use crate::protocols::http::v2::client::{drive_connection, Http2Session};
use crate::protocols::{Digest, Stream};
use crate::upstreams::peer::{Peer, ALPN};
use bytes::Bytes;
use h2::client::SendRequest;
use log::debug;
use parking_lot::RwLock;
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use pingora_pool::{ConnectionMeta, ConnectionPool, PoolNode};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
struct Stub(SendRequest<Bytes>);
impl Stub {
async fn new_stream(&self) -> Result<SendRequest<Bytes>> {
let send_req = self.0.clone();
send_req
.ready()
.await
.or_err(H2Error, "while creating new stream")
}
}
pub(crate) struct ConnectionRefInner {
connection_stub: Stub,
closed: watch::Receiver<bool>,
ping_timeout_occurred: Arc<AtomicBool>,
id: i32,
// max concurrent streams this connection is allowed to create
max_streams: usize,
// how many concurrent streams already active
current_streams: AtomicUsize,
// because `SendRequest` doesn't actually have access to the underlying Stream,
// we log info about timing and tcp info here.
pub(crate) digest: Digest,
}
#[derive(Clone)]
pub(crate) struct ConnectionRef(Arc<ConnectionRefInner>);
impl ConnectionRef {
pub fn new(
send_req: SendRequest<Bytes>,
closed: watch::Receiver<bool>,
ping_timeout_occurred: Arc<AtomicBool>,
id: i32,
max_streams: usize,
digest: Digest,
) -> Self {
ConnectionRef(Arc::new(ConnectionRefInner {
connection_stub: Stub(send_req),
closed,
ping_timeout_occurred,
id,
max_streams,
current_streams: AtomicUsize::new(0),
digest,
}))
}
pub fn more_streams_allowed(&self) -> bool {
self.0.max_streams > self.0.current_streams.load(Ordering::Relaxed)
}
pub fn is_idle(&self) -> bool {
self.0.current_streams.load(Ordering::Relaxed) == 0
}
pub fn release_stream(&self) {
self.0.current_streams.fetch_sub(1, Ordering::SeqCst);
}
pub fn id(&self) -> i32 {
self.0.id
}
pub fn digest(&self) -> &Digest {
&self.0.digest
}
pub fn ping_timedout(&self) -> bool {
self.0.ping_timeout_occurred.load(Ordering::Relaxed)
}
pub fn is_closed(&self) -> bool {
*self.0.closed.borrow()
}
    // spawn a stream if more streams are allowed, otherwise return Ok(None)
pub async fn spawn_stream(&self) -> Result<Option<Http2Session>> {
        // Atomically check whether current_streams is over the limit:
        // a separate load(), compare, and then fetch_add() cannot guarantee atomicity
let current_streams = self.0.current_streams.fetch_add(1, Ordering::SeqCst);
if current_streams >= self.0.max_streams {
// already over the limit, reset the counter to the previous value
self.0.current_streams.fetch_sub(1, Ordering::SeqCst);
return Ok(None);
}
let send_req = self.0.connection_stub.new_stream().await.map_err(|e| {
// fail to create the stream, reset the counter
self.0.current_streams.fetch_sub(1, Ordering::SeqCst);
e
})?;
Ok(Some(Http2Session::new(send_req, self.clone())))
}
}
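The slot reservation in `spawn_stream()` above relies on `fetch_add` returning the previous value: the increment and the limit check happen against one atomic operation, and the counter is rolled back on failure. A separate `load()` followed by `fetch_add()` would let two callers both pass the check and overshoot the limit. The pattern in isolation (function name is illustrative):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Reserve a stream slot the way spawn_stream does: fetch_add first, then roll
// back if the previous value was already at the limit.
fn try_reserve(current: &AtomicUsize, max: usize) -> bool {
    let prev = current.fetch_add(1, Ordering::SeqCst);
    if prev >= max {
        // already over the limit: undo the optimistic increment
        current.fetch_sub(1, Ordering::SeqCst);
        false
    } else {
        true
    }
}

fn main() {
    let current = AtomicUsize::new(0);
    assert!(try_reserve(&current, 2));
    assert!(try_reserve(&current, 2));
    assert!(!try_reserve(&current, 2)); // third reservation is rejected
    // the failed attempt left the counter untouched
    assert_eq!(current.load(Ordering::SeqCst), 2);
}
```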
struct InUsePool {
// TODO: use pingora hashmap to shard the lock contention
pools: RwLock<HashMap<u64, PoolNode<ConnectionRef>>>,
}
impl InUsePool {
fn new() -> Self {
InUsePool {
pools: RwLock::new(HashMap::new()),
}
}
fn insert(&self, reuse_hash: u64, conn: ConnectionRef) {
{
let pools = self.pools.read();
if let Some(pool) = pools.get(&reuse_hash) {
pool.insert(conn.id(), conn);
return;
}
} // drop read lock
let pool = PoolNode::new();
pool.insert(conn.id(), conn);
let mut pools = self.pools.write();
pools.insert(reuse_hash, pool);
}
    // retrieve an h2 conn ref to create a new stream
    // the caller should return the conn ref to this pool if there is still
    // capacity left for more streams
fn get(&self, reuse_hash: u64) -> Option<ConnectionRef> {
let pools = self.pools.read();
pools.get(&reuse_hash)?.get_any().map(|v| v.1)
}
    // release an h2 stream; this function returns the corresponding ConnectionRef (if it exists)
    // the caller should update the ref and then decide where to put it (in-use pool or idle pool)
fn release(&self, reuse_hash: u64, id: i32) -> Option<ConnectionRef> {
let pools = self.pools.read();
if let Some(pool) = pools.get(&reuse_hash) {
pool.remove(id)
} else {
None
}
}
}
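`InUsePool::insert()` above takes the shared read lock on the common path and only falls back to the exclusive write lock when the per-peer `PoolNode` does not exist yet, because `PoolNode` has its own interior mutability. The same read-then-write pattern can be sketched with std types only (the `Pool` type here is a simplified stand-in, using `Mutex<Vec<i32>>` in place of `PoolNode`):

```rust
use std::collections::HashMap;
use std::sync::{Mutex, RwLock};

// Read-lock fast path with a write-lock fallback: most inserts only need the
// shared read lock on the outer map, because the per-key bucket already
// exists and carries its own interior mutability.
struct Pool {
    buckets: RwLock<HashMap<u64, Mutex<Vec<i32>>>>,
}

impl Pool {
    fn insert(&self, key: u64, id: i32) {
        {
            let buckets = self.buckets.read().unwrap();
            if let Some(bucket) = buckets.get(&key) {
                bucket.lock().unwrap().push(id);
                return; // fast path: no exclusive lock on the outer map
            }
        } // read guard dropped here, before taking the write lock
        let mut buckets = self.buckets.write().unwrap();
        buckets.entry(key).or_default().lock().unwrap().push(id);
    }

    fn len(&self, key: u64) -> usize {
        self.buckets
            .read()
            .unwrap()
            .get(&key)
            .map_or(0, |b| b.lock().unwrap().len())
    }
}

fn main() {
    let pool = Pool { buckets: RwLock::new(HashMap::new()) };
    pool.insert(42, 1); // slow path: creates the bucket under the write lock
    pool.insert(42, 2); // fast path: read lock only
    assert_eq!(pool.len(42), 2);
}
```

The write lock is taken only on the first insert for a given key, so steady-state traffic contends only on the cheap shared lock.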
const DEFAULT_POOL_SIZE: usize = 128;
/// Http2 connector
pub struct Connector {
// just for creating connections, the Stream of h2 should be reused
transport: TransportConnector,
// the h2 connection idle pool
idle_pool: Arc<ConnectionPool<ConnectionRef>>,
// the pool of h2 connections that have ongoing streams
in_use_pool: InUsePool,
}
impl Connector {
/// Create a new [Connector] from the given [ConnectorOptions]
pub fn new(options: Option<ConnectorOptions>) -> Self {
let pool_size = options
.as_ref()
.map_or(DEFAULT_POOL_SIZE, |o| o.keepalive_pool_size);
// connection offload is handled by the [TransportConnector]
Connector {
transport: TransportConnector::new(options),
idle_pool: Arc::new(ConnectionPool::new(pool_size)),
in_use_pool: InUsePool::new(),
}
}
/// Create a new Http2 connection to the given server
///
/// Either an Http2 or Http1 session can be returned depending on the server's preference.
pub async fn new_http_session<P: Peer + Send + Sync + 'static>(
&self,
peer: &P,
) -> Result<HttpSession> {
let stream = self.transport.new_stream(peer).await?;
// check alpn
match stream.selected_alpn_proto() {
Some(ALPN::H2) => { /* continue */ }
Some(_) => {
// H2 not supported
return Ok(HttpSession::H1(Http1Session::new(stream)));
}
None => {
// if tls but no ALPN, default to h1
// else if plaintext and min http version is 1, this is most likely h1
if peer.tls()
|| peer
.get_peer_options()
.map_or(true, |o| o.alpn.get_min_http_version() == 1)
{
return Ok(HttpSession::H1(Http1Session::new(stream)));
}
// else: min http version=H2 over plaintext, there is no ALPN anyways, we trust
// the caller that the server speaks h2c
}
}
let max_h2_stream = peer.get_peer_options().map_or(1, |o| o.max_h2_streams);
let conn = handshake(stream, max_h2_stream, peer.h2_ping_interval()).await?;
let h2_stream = conn
.spawn_stream()
.await?
.expect("newly created connections should have at least one free stream");
if conn.more_streams_allowed() {
self.in_use_pool.insert(peer.reuse_hash(), conn);
}
Ok(HttpSession::H2(h2_stream))
}
/// Try to create a new http2 stream from any existing H2 connection.
///
/// None means there is no "free" connection left.
pub async fn reused_http_session<P: Peer + Send + Sync + 'static>(
&self,
peer: &P,
) -> Result<Option<Http2Session>> {
// check in use pool first so that we use fewer total connections
// then idle pool
let reuse_hash = peer.reuse_hash();
// NOTE: We grab a conn from the pools, create a new stream and put the conn back if the
        // conn has more free streams. During this process another caller could arrive but not be
        // able to find the conn even though the conn still has free streams to use.
        // We accept this false negative to keep the implementation simple. This false negative
        // only makes an actual impact when there are just a few connections.
        // Alternative design 1: give each free stream its own conn object, which requires a lot of Arc<>s.
        // Alternative design 2: mutex the pool, which creates lock contention when concurrency is high.
        // Alternative design 3: do not pop the conn from the pool, so that multiple callers can grab it;
        // this causes an issue where spawn_stream() could return None because other callers got there
        // first, so a caller might have to retry or give up. This issue is more likely to happen
        // when concurrency is high.
let maybe_conn = self
.in_use_pool
.get(reuse_hash)
.or_else(|| self.idle_pool.get(&reuse_hash));
if let Some(conn) = maybe_conn {
let h2_stream = conn
.spawn_stream()
.await?
.expect("connection from the pools should have free stream to allocate");
if conn.more_streams_allowed() {
self.in_use_pool.insert(reuse_hash, conn);
}
Ok(Some(h2_stream))
} else {
Ok(None)
}
}
/// Release a finished h2 stream.
///
/// This function will terminate the [Http2Session]. The corresponding h2 connection will now
/// have one more free stream to use.
///
/// The h2 connection will be closed after `idle_timeout` if it has no active streams.
pub fn release_http_session<P: Peer + Send + Sync + 'static>(
&self,
session: Http2Session,
peer: &P,
idle_timeout: Option<Duration>,
) {
let id = session.conn.id();
let reuse_hash = peer.reuse_hash();
// get a ref to the connection, which we might need below, before dropping the h2
let conn = session.conn();
// this drop() will both drop the actual stream and call the conn.release_stream()
drop(session);
// find and remove the conn stored in in_use_pool so that it could be put in the idle pool
// if necessary
let conn = self.in_use_pool.release(reuse_hash, id).unwrap_or(conn);
if conn.is_closed() {
// Already dead h2 connection
return;
}
if conn.is_idle() {
let meta = ConnectionMeta {
key: reuse_hash,
id,
};
let closed = conn.0.closed.clone();
let (notify_evicted, watch_use) = self.idle_pool.put(&meta, conn);
if let Some(to) = idle_timeout {
let pool = self.idle_pool.clone(); //clone the arc
let rt = pingora_runtime::current_handle();
rt.spawn(async move {
pool.idle_timeout(&meta, to, notify_evicted, closed, watch_use)
.await;
});
}
} else {
self.in_use_pool.insert(reuse_hash, conn);
}
}
/// Tell the connector to always send h1 for ALPN for the given peer in the future.
pub fn prefer_h1(&self, peer: &impl Peer) {
self.transport.prefer_h1(peer);
}
pub(crate) fn h1_is_preferred(&self, peer: &impl Peer) -> bool {
self.transport
.preferred_http_version
.get(peer)
.map_or(false, |v| matches!(v, ALPN::H1))
}
}
// The h2 library we use has unbounded internal buffering, which will cause excessive memory
// consumption when the downstream is slower than the upstream. This window size caps the buffering by
// limiting how much data can be in flight. However, setting this value also caps the max
// download speed by limiting the bandwidth-delay product of the link.
// Long term, we should advertise a large window but shrink it when a small buffer is full.
// 8 MBytes = 80 MBytes/s x 100 ms, which should be enough for most links.
const H2_WINDOW_SIZE: u32 = 1 << 23;
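The comment above is a bandwidth-delay product calculation: at most one window of data can be in flight per round trip, so an 8 MiB window caps throughput at roughly 80 MiB/s on a 100 ms link. The arithmetic, checked in isolation (the helper function is illustrative, not part of this module):

```rust
// Throughput ceiling implied by a flow-control window: one window per RTT.
fn max_throughput_bytes_per_sec(window_bytes: u64, rtt_ms: u64) -> u64 {
    window_bytes * 1000 / rtt_ms
}

fn main() {
    const H2_WINDOW_SIZE: u64 = 1 << 23; // same value as the constant above
    assert_eq!(H2_WINDOW_SIZE, 8 * 1024 * 1024); // 8 MiB
    // 8 MiB window / 100 ms RTT = 80 MiB/s ceiling
    assert_eq!(
        max_throughput_bytes_per_sec(H2_WINDOW_SIZE, 100),
        80 * 1024 * 1024
    );
}
```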
async fn handshake(
stream: Stream,
max_streams: usize,
h2_ping_interval: Option<Duration>,
) -> Result<ConnectionRef> {
use h2::client::Builder;
use pingora_runtime::current_handle;
    // Safeguard: new_http_session() assumes there should be at least one free stream
if max_streams == 0 {
return Error::e_explain(H2Error, "zero max_stream configured");
}
let id = stream.id();
let digest = Digest {
// NOTE: this field is always false because the digest is shared across all streams
// The streams should log their own reuse info
ssl_digest: stream.get_ssl_digest(),
// TODO: log h2 handshake time
timing_digest: stream.get_timing_digest(),
proxy_digest: stream.get_proxy_digest(),
};
// TODO: make these configurable
let (send_req, connection) = Builder::new()
.enable_push(false)
.initial_max_send_streams(max_streams)
// The limit for the server. Server push is not allowed, so this value doesn't matter
.max_concurrent_streams(1)
.max_frame_size(64 * 1024) // advise server to send larger frames
.initial_window_size(H2_WINDOW_SIZE)
// should this be max_streams * H2_WINDOW_SIZE?
.initial_connection_window_size(H2_WINDOW_SIZE)
.handshake(stream)
.await
.or_err(HandshakeError, "during H2 handshake")?;
debug!("H2 handshake to server done.");
let ping_timeout_occurred = Arc::new(AtomicBool::new(false));
let ping_timeout_clone = ping_timeout_occurred.clone();
let max_allowed_streams = std::cmp::min(max_streams, connection.max_concurrent_send_streams());
    // Safeguard: new_http_session() assumes there is at least one free stream
    // Servers rarely advertise zero max streams.
if max_allowed_streams == 0 {
return Error::e_explain(H2Error, "zero max_concurrent_send_streams received");
}
let (closed_tx, closed_rx) = watch::channel(false);
current_handle().spawn(async move {
drive_connection(
connection,
id,
closed_tx,
h2_ping_interval,
ping_timeout_clone,
)
.await;
});
Ok(ConnectionRef::new(
send_req,
closed_rx,
ping_timeout_occurred,
id,
max_allowed_streams,
digest,
))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::upstreams::peer::HttpPeer;
#[tokio::test]
async fn test_connect_h2() {
let connector = Connector::new(None);
let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into());
peer.options.set_http_version(2, 2);
let h2 = connector.new_http_session(&peer).await.unwrap();
match h2 {
HttpSession::H1(_) => panic!("expect h2"),
HttpSession::H2(h2_stream) => assert!(!h2_stream.ping_timedout()),
}
}
#[tokio::test]
async fn test_connect_h1() {
let connector = Connector::new(None);
let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into());
// a hack to force h1, new_http_session() in the future might validate this setting
peer.options.set_http_version(1, 1);
let h2 = connector.new_http_session(&peer).await.unwrap();
match h2 {
HttpSession::H1(_) => {}
HttpSession::H2(_) => panic!("expect h1"),
}
}
#[tokio::test]
async fn test_connect_h1_plaintext() {
let connector = Connector::new(None);
let mut peer = HttpPeer::new(("1.1.1.1", 80), false, "".into());
peer.options.set_http_version(2, 1);
let h2 = connector.new_http_session(&peer).await.unwrap();
match h2 {
HttpSession::H1(_) => {}
HttpSession::H2(_) => panic!("expect h1"),
}
}
#[tokio::test]
async fn test_h2_single_stream() {
let connector = Connector::new(None);
let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into());
peer.options.set_http_version(2, 2);
peer.options.max_h2_streams = 1;
let h2 = connector.new_http_session(&peer).await.unwrap();
let h2_1 = match h2 {
HttpSession::H1(_) => panic!("expect h2"),
HttpSession::H2(h2_stream) => h2_stream,
};
let id = h2_1.conn.id();
assert!(connector
.reused_http_session(&peer)
.await
.unwrap()
.is_none());
connector.release_http_session(h2_1, &peer, None);
let h2_2 = connector.reused_http_session(&peer).await.unwrap().unwrap();
assert_eq!(id, h2_2.conn.id());
connector.release_http_session(h2_2, &peer, None);
let h2_3 = connector.reused_http_session(&peer).await.unwrap().unwrap();
assert_eq!(id, h2_3.conn.id());
}
#[tokio::test]
async fn test_h2_multiple_stream() {
let connector = Connector::new(None);
let mut peer = HttpPeer::new(("1.1.1.1", 443), true, "one.one.one.one".into());
peer.options.set_http_version(2, 2);
peer.options.max_h2_streams = 3;
let h2 = connector.new_http_session(&peer).await.unwrap();
let h2_1 = match h2 {
HttpSession::H1(_) => panic!("expect h2"),
HttpSession::H2(h2_stream) => h2_stream,
};
let id = h2_1.conn.id();
let h2_2 = connector.reused_http_session(&peer).await.unwrap().unwrap();
assert_eq!(id, h2_2.conn.id());
let h2_3 = connector.reused_http_session(&peer).await.unwrap().unwrap();
assert_eq!(id, h2_3.conn.id());
// max stream is 3 for now
assert!(connector
.reused_http_session(&peer)
.await
.unwrap()
.is_none());
connector.release_http_session(h2_1, &peer, None);
let h2_4 = connector.reused_http_session(&peer).await.unwrap().unwrap();
assert_eq!(id, h2_4.conn.id());
connector.release_http_session(h2_2, &peer, None);
connector.release_http_session(h2_3, &peer, None);
connector.release_http_session(h2_4, &peer, None);
// all streams are released, now the connection is idle
let h2_5 = connector.reused_http_session(&peer).await.unwrap().unwrap();
assert_eq!(id, h2_5.conn.id());
}
}

// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use log::debug;
use pingora_error::{Context, Error, ErrorType::*, OrErr, Result};
use rand::seq::SliceRandom;
use std::net::SocketAddr as InetSocketAddr;
use crate::protocols::l4::ext::{connect as tcp_connect, connect_uds, set_tcp_keepalive};
use crate::protocols::l4::socket::SocketAddr;
use crate::protocols::l4::stream::Stream;
use crate::upstreams::peer::Peer;
/// Establish a connection (l4) to the given peer using its settings and an optional bind address.
pub async fn connect<P>(peer: &P, bind_to: Option<InetSocketAddr>) -> Result<Stream>
where
P: Peer + Send + Sync,
{
if peer.get_proxy().is_some() {
return proxy_connect(peer)
.await
.err_context(|| format!("Fail to establish CONNECT proxy: {}", peer));
}
let mut stream: Stream = match peer.address() {
SocketAddr::Inet(addr) => {
let connect_future = tcp_connect(addr, bind_to.as_ref());
let conn_res = match peer.connection_timeout() {
Some(t) => pingora_timeout::timeout(t, connect_future)
.await
.explain_err(ConnectTimedout, |_| {
format!("timeout {t:?} connecting to server {peer}")
})?,
None => connect_future.await,
};
match conn_res {
Ok(socket) => {
debug!("connected to new server: {}", peer.address());
if let Some(ka) = peer.tcp_keepalive() {
debug!("Setting tcp keepalive");
set_tcp_keepalive(&socket, ka)?;
}
Ok(socket.into())
}
Err(e) => {
let c = format!("Fail to connect to {peer}");
match e.etype() {
SocketError | BindError => Error::e_because(InternalError, c, e),
_ => Err(e.more_context(c)),
}
}
}
}
SocketAddr::Unix(addr) => {
let connect_future = connect_uds(
addr.as_pathname()
.expect("non-pathname unix sockets not supported as peer"),
);
let conn_res = match peer.connection_timeout() {
Some(t) => pingora_timeout::timeout(t, connect_future)
.await
.explain_err(ConnectTimedout, |_| {
format!("timeout {t:?} connecting to server {peer}")
})?,
None => connect_future.await,
};
match conn_res {
Ok(socket) => {
debug!("connected to new server: {}", peer.address());
// no SO_KEEPALIVE for UDS
Ok(socket.into())
}
Err(e) => {
let c = format!("Fail to connect to {peer}");
match e.etype() {
SocketError | BindError => Error::e_because(InternalError, c, e),
_ => Err(e.more_context(c)),
}
}
}
}
}?;
let tracer = peer.get_tracer();
if let Some(t) = tracer {
t.0.on_connected();
stream.tracer = Some(t);
}
stream.set_nodelay()?;
Ok(stream)
}
pub(crate) fn bind_to_random<P: Peer>(
peer: &P,
v4_list: &[InetSocketAddr],
v6_list: &[InetSocketAddr],
) -> Option<InetSocketAddr> {
let selected = peer.get_peer_options().and_then(|o| o.bind_to);
if selected.is_some() {
return selected;
}
fn bind_to_ips(ips: &[InetSocketAddr]) -> Option<InetSocketAddr> {
match ips.len() {
0 => None,
1 => Some(ips[0]),
_ => {
// pick a random bind ip
ips.choose(&mut rand::thread_rng()).copied()
}
}
}
match peer.address() {
SocketAddr::Inet(sockaddr) => match sockaddr {
InetSocketAddr::V4(_) => bind_to_ips(v4_list),
InetSocketAddr::V6(_) => bind_to_ips(v6_list),
},
SocketAddr::Unix(_) => None,
}
}
use crate::protocols::raw_connect;
async fn proxy_connect<P: Peer>(peer: &P) -> Result<Stream> {
// safe to unwrap
let proxy = peer.get_proxy().unwrap();
let options = peer.get_peer_options().unwrap();
// combine required and optional headers
let mut headers = proxy
.headers
.iter()
.chain(options.extra_proxy_headers.iter());
// not likely to timeout during connect() to UDS
let stream: Box<Stream> = Box::new(
connect_uds(&proxy.next_hop)
.await
.or_err_with(ConnectError, || {
format!("CONNECT proxy connect() error to {:?}", &proxy.next_hop)
})?
.into(),
);
let req_header = raw_connect::generate_connect_header(&proxy.host, proxy.port, &mut headers)?;
let fut = raw_connect::connect(stream, &req_header);
let (mut stream, digest) = match peer.connection_timeout() {
Some(t) => pingora_timeout::timeout(t, fut)
.await
.explain_err(ConnectTimedout, |_| "establishing CONNECT proxy")?,
None => fut.await,
}
.map_err(|mut e| {
// http protocol may ask to retry if reused client
e.retry.decide_reuse(false);
e
})?;
debug!("CONNECT proxy established: {:?}", proxy);
stream.set_proxy_digest(digest);
let stream = stream.into_any().downcast::<Stream>().unwrap(); // safe, it is Stream from above
Ok(*stream)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::upstreams::peer::{BasicPeer, HttpPeer, Proxy};
use std::collections::BTreeMap;
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
use tokio::net::UnixListener;
#[tokio::test]
async fn test_conn_error_refused() {
let peer = BasicPeer::new("127.0.0.1:79"); // hopefully port 79 is not used
let new_session = connect(&peer, None).await;
assert_eq!(new_session.unwrap_err().etype(), &ConnectRefused)
}
// TODO broken on arm64
#[ignore]
#[tokio::test]
async fn test_conn_error_no_route() {
let peer = BasicPeer::new("[::3]:79"); // no route
let new_session = connect(&peer, None).await;
assert_eq!(new_session.unwrap_err().etype(), &ConnectNoRoute)
}
#[tokio::test]
async fn test_conn_error_addr_not_avail() {
let peer = HttpPeer::new("127.0.0.1:121".to_string(), false, "".to_string());
let new_session = connect(&peer, Some("192.0.2.2:0".parse().unwrap())).await;
assert_eq!(new_session.unwrap_err().etype(), &InternalError)
}
#[tokio::test]
async fn test_conn_error_other() {
let peer = HttpPeer::new("240.0.0.1:80".to_string(), false, "".to_string()); // non localhost
// create an error: cannot send from src addr: localhost to dst addr: a public IP
let new_session = connect(&peer, Some("127.0.0.1:0".parse().unwrap())).await;
let error = new_session.unwrap_err();
// XXX: some system will allow the socket to bind and connect without error, only to timeout
assert!(error.etype() == &ConnectError || error.etype() == &ConnectTimedout)
}
#[tokio::test]
async fn test_conn_timeout() {
// 192.0.2.1 is effectively a blackhole
let mut peer = BasicPeer::new("192.0.2.1:79");
peer.options.connection_timeout = Some(std::time::Duration::from_millis(1)); //1ms
let new_session = connect(&peer, None).await;
assert_eq!(new_session.unwrap_err().etype(), &ConnectTimedout)
}
#[tokio::test]
async fn test_connect_proxy_fail() {
let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
let mut path = PathBuf::new();
path.push("/tmp/123");
peer.proxy = Some(Proxy {
next_hop: path.into(),
host: "1.1.1.1".into(),
port: 80,
headers: BTreeMap::new(),
});
let new_session = connect(&peer, None).await;
let e = new_session.unwrap_err();
assert_eq!(e.etype(), &ConnectError);
assert!(!e.retry());
}
const MOCK_UDS_PATH: &str = "/tmp/test_unix_connect_proxy.sock";
// one-off mock server
async fn mock_connect_server() {
let _ = std::fs::remove_file(MOCK_UDS_PATH);
let listener = UnixListener::bind(MOCK_UDS_PATH).unwrap();
if let Ok((mut stream, _addr)) = listener.accept().await {
stream.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
// wait a bit so that the client can read
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
let _ = std::fs::remove_file(MOCK_UDS_PATH);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_connect_proxy_work() {
tokio::spawn(async {
mock_connect_server().await;
});
// wait for the server to start
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
let mut path = PathBuf::new();
path.push(MOCK_UDS_PATH);
peer.proxy = Some(Proxy {
next_hop: path.into(),
host: "1.1.1.1".into(),
port: 80,
headers: BTreeMap::new(),
});
let new_session = connect(&peer, None).await;
assert!(new_session.is_ok());
}
const MOCK_BAD_UDS_PATH: &str = "/tmp/test_unix_bad_connect_proxy.sock";
// one-off mock bad proxy
// closes connection upon accepting
async fn mock_connect_bad_server() {
let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH);
let listener = UnixListener::bind(MOCK_BAD_UDS_PATH).unwrap();
if let Ok((mut stream, _addr)) = listener.accept().await {
stream.shutdown().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_connect_proxy_conn_closed() {
tokio::spawn(async {
mock_connect_bad_server().await;
});
// wait for the server to start
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string());
let mut path = PathBuf::new();
path.push(MOCK_BAD_UDS_PATH);
peer.proxy = Some(Proxy {
next_hop: path.into(),
host: "1.1.1.1".into(),
port: 80,
headers: BTreeMap::new(),
});
let new_session = connect(&peer, None).await;
let err = new_session.unwrap_err();
assert_eq!(err.etype(), &ConnectionClosed);
assert!(!err.retry());
}
}

// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Connecting to servers
pub mod http;
mod l4;
mod offload;
mod tls;
use crate::protocols::Stream;
use crate::server::configuration::ServerConf;
use crate::tls::ssl::SslConnector;
use crate::upstreams::peer::{Peer, ALPN};
use l4::connect as l4_connect;
use log::{debug, error, warn};
use offload::OffloadRuntime;
use parking_lot::RwLock;
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use pingora_pool::{ConnectionMeta, ConnectionPool};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
/// The options to configure a [TransportConnector]
#[derive(Clone)]
pub struct ConnectorOptions {
/// Path to the CA file used to validate server certs.
///
/// If `None`, the CA in the [default](https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_default_verify_paths.html)
/// locations will be loaded
pub ca_file: Option<String>,
/// The default client cert and key to use for mTLS
///
    /// Each individual connection can use its own cert and key to override this.
pub cert_key_file: Option<(String, String)>,
/// How many connections to keepalive
pub keepalive_pool_size: usize,
/// Optionally offload the connection establishment to dedicated thread pools
///
/// TCP and TLS connection establishment can be CPU intensive. Sometimes such tasks can slow
/// down the entire service, which causes timeouts which leads to more connections which
/// snowballs the issue. Use this option to isolate these CPU intensive tasks from impacting
/// other traffic.
///
/// Syntax: (#pools, #thread in each pool)
pub offload_threadpool: Option<(usize, usize)>,
    /// Bind to any of the given source IPv4 addresses
    pub bind_to_v4: Vec<SocketAddr>,
    /// Bind to any of the given source IPv6 addresses
    pub bind_to_v6: Vec<SocketAddr>,
}
impl ConnectorOptions {
/// Derive the [ConnectorOptions] from a [ServerConf]
pub fn from_server_conf(server_conf: &ServerConf) -> Self {
// if both pools and threads are Some(>0)
let offload_threadpool = server_conf
.upstream_connect_offload_threadpools
.zip(server_conf.upstream_connect_offload_thread_per_pool)
.filter(|(pools, threads)| *pools > 0 && *threads > 0);
// create SocketAddrs with port 0 for src addr bind
let bind_to_v4 = server_conf
.client_bind_to_ipv4
.iter()
.map(|v4| {
let ip = v4.parse().unwrap();
SocketAddr::new(ip, 0)
})
.collect();
let bind_to_v6 = server_conf
.client_bind_to_ipv6
.iter()
.map(|v6| {
let ip = v6.parse().unwrap();
SocketAddr::new(ip, 0)
})
.collect();
ConnectorOptions {
ca_file: server_conf.ca_file.clone(),
cert_key_file: None, // TODO: use it
keepalive_pool_size: server_conf.upstream_keepalive_pool_size,
offload_threadpool,
bind_to_v4,
bind_to_v6,
}
}
/// Create a new [ConnectorOptions] with the given keepalive pool size
pub fn new(keepalive_pool_size: usize) -> Self {
ConnectorOptions {
ca_file: None,
cert_key_file: None,
keepalive_pool_size,
offload_threadpool: None,
bind_to_v4: vec![],
bind_to_v6: vec![],
}
}
}
/// [TransportConnector] provides APIs to connect to servers via TCP or TLS with connection reuse
pub struct TransportConnector {
tls_ctx: tls::Connector,
connection_pool: Arc<ConnectionPool<Arc<Mutex<Stream>>>>,
offload: Option<OffloadRuntime>,
bind_to_v4: Vec<SocketAddr>,
bind_to_v6: Vec<SocketAddr>,
preferred_http_version: PreferredHttpVersion,
}
const DEFAULT_POOL_SIZE: usize = 128;
impl TransportConnector {
/// Create a new [TransportConnector] with the given [ConnectorOptions]
pub fn new(mut options: Option<ConnectorOptions>) -> Self {
let pool_size = options
.as_ref()
.map_or(DEFAULT_POOL_SIZE, |c| c.keepalive_pool_size);
        // Take the offload setting here because this layer implements offloading,
        // so lower-layer stacks don't need to offload again.
let offload = options.as_mut().and_then(|o| o.offload_threadpool.take());
let bind_to_v4 = options
.as_ref()
.map_or_else(Vec::new, |o| o.bind_to_v4.clone());
let bind_to_v6 = options
.as_ref()
.map_or_else(Vec::new, |o| o.bind_to_v6.clone());
TransportConnector {
tls_ctx: tls::Connector::new(options),
connection_pool: Arc::new(ConnectionPool::new(pool_size)),
offload: offload.map(|v| OffloadRuntime::new(v.0, v.1)),
bind_to_v4,
bind_to_v6,
preferred_http_version: PreferredHttpVersion::new(),
}
}
/// Connect to the given server [Peer]
///
/// No connection is reused.
pub async fn new_stream<P: Peer + Send + Sync + 'static>(&self, peer: &P) -> Result<Stream> {
let rt = self
.offload
.as_ref()
.map(|o| o.get_runtime(peer.reuse_hash()));
let bind_to = l4::bind_to_random(peer, &self.bind_to_v4, &self.bind_to_v6);
let alpn_override = self.preferred_http_version.get(peer);
let stream = if let Some(rt) = rt {
let peer = peer.clone();
let tls_ctx = self.tls_ctx.clone();
rt.spawn(async move { do_connect(&peer, bind_to, alpn_override, &tls_ctx.ctx).await })
.await
.or_err(InternalError, "offload runtime failure")??
} else {
do_connect(peer, bind_to, alpn_override, &self.tls_ctx.ctx).await?
};
Ok(stream)
}
/// Try to find a reusable connection to the given server [Peer]
pub async fn reused_stream<P: Peer + Send + Sync>(&self, peer: &P) -> Option<Stream> {
match self.connection_pool.get(&peer.reuse_hash()) {
Some(s) => {
debug!("find reusable stream, trying to acquire it");
{
let _ = s.lock().await;
} // wait for the idle poll to release it
match Arc::try_unwrap(s) {
Ok(l) => {
let mut stream = l.into_inner();
// test_reusable_stream: we assume server would never actively send data
// first on an idle stream.
if peer.matches_fd(stream.id()) && test_reusable_stream(&mut stream) {
Some(stream)
} else {
None
}
}
Err(_) => {
error!("failed to acquire reusable stream");
None
}
}
}
None => {
debug!("No reusable connection found for {peer}");
None
}
}
}
/// Return the [Stream] to the [TransportConnector] for connection reuse.
///
    /// Not all TCP/TLS connections can be reused. It is the caller's responsibility to make sure
    /// that the protocol over the [Stream] supports connection reuse and the [Stream] itself is
    /// ready to be reused.
    ///
    /// If a [Stream] is dropped instead of being returned via this function, it will be closed.
pub fn release_stream(
&self,
mut stream: Stream,
key: u64, // usually peer.reuse_hash()
idle_timeout: Option<std::time::Duration>,
) {
if !test_reusable_stream(&mut stream) {
return;
}
let id = stream.id();
let meta = ConnectionMeta::new(key, id);
debug!("Try to keepalive client session");
let stream = Arc::new(Mutex::new(stream));
let locked_stream = stream.clone().try_lock_owned().unwrap(); // safe as we just created it
let (notify_close, watch_use) = self.connection_pool.put(&meta, stream);
let pool = self.connection_pool.clone(); //clone the arc
let rt = pingora_runtime::current_handle();
rt.spawn(async move {
pool.idle_poll(locked_stream, &meta, idle_timeout, notify_close, watch_use)
.await;
});
}
/// Get a stream to the given server [Peer]
///
/// This function will try to find a reusable [Stream] first. If there is none, a new connection
/// will be made to the server.
///
/// The returned boolean will indicate whether the stream is reused.
pub async fn get_stream<P: Peer + Send + Sync + 'static>(
&self,
peer: &P,
) -> Result<(Stream, bool)> {
let reused_stream = self.reused_stream(peer).await;
if let Some(s) = reused_stream {
Ok((s, true))
} else {
let s = self.new_stream(peer).await?;
Ok((s, false))
}
}
    /// Tell the connector to always advertise only h1 via ALPN to the given peer in the future.
pub fn prefer_h1(&self, peer: &impl Peer) {
self.preferred_http_version.add(peer, 1);
}
}
// Perform the actual L4 and TLS connection steps while respecting the peer's
// connection timeout, if there is one
async fn do_connect<P: Peer + Send + Sync>(
peer: &P,
bind_to: Option<SocketAddr>,
alpn_override: Option<ALPN>,
tls_ctx: &SslConnector,
) -> Result<Stream> {
    // Create the future that does the connection, but don't evaluate it until
    // we decide if we need a timeout or not
let connect_future = do_connect_inner(peer, bind_to, alpn_override, tls_ctx);
match peer.total_connection_timeout() {
Some(t) => match pingora_timeout::timeout(t, connect_future).await {
Ok(res) => res,
Err(_) => Error::e_explain(
ConnectTimedout,
format!("connecting to server {peer}, total-connection timeout {t:?}"),
),
},
None => connect_future.await,
}
}
// Perform the actual L4 and tls connection steps with no timeout
async fn do_connect_inner<P: Peer + Send + Sync>(
peer: &P,
bind_to: Option<SocketAddr>,
alpn_override: Option<ALPN>,
tls_ctx: &SslConnector,
) -> Result<Stream> {
let stream = l4_connect(peer, bind_to).await?;
if peer.tls() {
let tls_stream = tls::connect(stream, peer, alpn_override, tls_ctx).await?;
Ok(Box::new(tls_stream))
} else {
Ok(Box::new(stream))
}
}
struct PreferredHttpVersion {
// TODO: shard to avoid the global lock
versions: RwLock<HashMap<u64, u8>>, // <hash of peer, version>
}
// TODO: limit the size of this
impl PreferredHttpVersion {
pub fn new() -> Self {
PreferredHttpVersion {
versions: RwLock::default(),
}
}
pub fn add(&self, peer: &impl Peer, version: u8) {
let key = peer.reuse_hash();
let mut v = self.versions.write();
v.insert(key, version);
}
pub fn get(&self, peer: &impl Peer) -> Option<ALPN> {
let key = peer.reuse_hash();
let v = self.versions.read();
v.get(&key)
.copied()
.map(|v| if v == 1 { ALPN::H1 } else { ALPN::H2H1 })
}
}
use futures::future::FutureExt;
use tokio::io::AsyncReadExt;
/// Test whether a stream is already closed or not reusable (server sent unexpected data)
fn test_reusable_stream(stream: &mut Stream) -> bool {
let mut buf = [0; 1];
let result = stream.read(&mut buf[..]).now_or_never();
if let Some(data_result) = result {
match data_result {
Ok(n) => {
if n == 0 {
debug!("Idle connection is closed");
} else {
warn!("Unexpected data read in idle connection");
}
}
Err(e) => {
debug!("Idle connection is broken: {e:?}");
}
}
false
} else {
true
}
}
#[cfg(test)]
mod tests {
use pingora_error::ErrorType;
use pingora_openssl::ssl::SslMethod;
use super::*;
use crate::upstreams::peer::BasicPeer;
// 192.0.2.1 is effectively a black hole
const BLACK_HOLE: &str = "192.0.2.1:79";
#[tokio::test]
async fn test_connect() {
let connector = TransportConnector::new(None);
let peer = BasicPeer::new("1.1.1.1:80");
// make a new connection to 1.1.1.1
let stream = connector.new_stream(&peer).await.unwrap();
connector.release_stream(stream, peer.reuse_hash(), None);
let (_, reused) = connector.get_stream(&peer).await.unwrap();
assert!(reused);
}
#[tokio::test]
async fn test_connect_tls() {
let connector = TransportConnector::new(None);
let mut peer = BasicPeer::new("1.1.1.1:443");
// BasicPeer will use tls when SNI is set
peer.sni = "one.one.one.one".to_string();
// make a new connection to https://1.1.1.1
let stream = connector.new_stream(&peer).await.unwrap();
connector.release_stream(stream, peer.reuse_hash(), None);
let (_, reused) = connector.get_stream(&peer).await.unwrap();
assert!(reused);
}
async fn do_test_conn_timeout(conf: Option<ConnectorOptions>) {
let connector = TransportConnector::new(conf);
let mut peer = BasicPeer::new(BLACK_HOLE);
peer.options.connection_timeout = Some(std::time::Duration::from_millis(1));
let stream = connector.new_stream(&peer).await;
match stream {
Ok(_) => panic!("should throw an error"),
Err(e) => assert_eq!(e.etype(), &ConnectTimedout),
}
}
#[tokio::test]
async fn test_conn_timeout() {
do_test_conn_timeout(None).await;
}
#[tokio::test]
async fn test_conn_timeout_with_offload() {
let mut conf = ConnectorOptions::new(8);
conf.offload_threadpool = Some((2, 2));
do_test_conn_timeout(Some(conf)).await;
}
#[tokio::test]
async fn test_connector_bind_to() {
// connect to remote while bind to localhost will fail
let peer = BasicPeer::new("240.0.0.1:80");
let mut conf = ConnectorOptions::new(1);
conf.bind_to_v4.push("127.0.0.1:0".parse().unwrap());
let connector = TransportConnector::new(Some(conf));
let stream = connector.new_stream(&peer).await;
let error = stream.unwrap_err();
// XXX: some system will allow the socket to bind and connect without error, only to timeout
assert!(error.etype() == &ConnectError || error.etype() == &ConnectTimedout)
}
    /// Helper function for testing error handling in the `do_connect` function.
    /// This assumes that the connection to the peer will fail, and returns
    /// the decomposed error type and message.
async fn get_do_connect_failure_with_peer(peer: &BasicPeer) -> (ErrorType, String) {
let ssl_connector = SslConnector::builder(SslMethod::tls()).unwrap().build();
let stream = do_connect(peer, None, None, &ssl_connector).await;
match stream {
Ok(_) => panic!("should throw an error"),
Err(e) => (
e.etype().clone(),
e.context
.as_ref()
.map(|ctx| ctx.as_str().to_owned())
.unwrap_or_default(),
),
}
}
#[tokio::test]
async fn test_do_connect_with_total_timeout() {
let mut peer = BasicPeer::new(BLACK_HOLE);
peer.options.total_connection_timeout = Some(std::time::Duration::from_millis(1));
let (etype, context) = get_do_connect_failure_with_peer(&peer).await;
assert_eq!(etype, ConnectTimedout);
assert!(context.contains("total-connection timeout"));
}
#[tokio::test]
async fn test_tls_connect_timeout_supersedes_total() {
let mut peer = BasicPeer::new(BLACK_HOLE);
peer.options.total_connection_timeout = Some(std::time::Duration::from_millis(10));
peer.options.connection_timeout = Some(std::time::Duration::from_millis(1));
let (etype, context) = get_do_connect_failure_with_peer(&peer).await;
assert_eq!(etype, ConnectTimedout);
assert!(!context.contains("total-connection timeout"));
}
#[tokio::test]
async fn test_do_connect_without_total_timeout() {
let peer = BasicPeer::new(BLACK_HOLE);
let (etype, context) = get_do_connect_failure_with_peer(&peer).await;
assert!(etype != ConnectTimedout || !context.contains("total-connection timeout"));
}
}

// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use log::debug;
use once_cell::sync::OnceCell;
use rand::Rng;
use tokio::runtime::{Builder, Handle};
use tokio::sync::oneshot::{channel, Sender};
// TODO: use pingora_runtime
// a shared runtime (thread pools)
pub(crate) struct OffloadRuntime {
shards: usize,
thread_per_shard: usize,
    // Lazily init the runtimes so that they are created after pingora
    // daemonizes itself. Otherwise the runtime threads are lost.
pools: OnceCell<Box<[(Handle, Sender<()>)]>>,
}
impl OffloadRuntime {
pub fn new(shards: usize, thread_per_shard: usize) -> Self {
assert!(shards != 0);
assert!(thread_per_shard != 0);
OffloadRuntime {
shards,
thread_per_shard,
pools: OnceCell::new(),
}
}
fn init_pools(&self) -> Box<[(Handle, Sender<()>)]> {
let threads = self.shards * self.thread_per_shard;
let mut pools = Vec::with_capacity(threads);
for _ in 0..threads {
            // We use single-thread runtimes to reduce the scheduling overhead of the
            // multithread tokio runtime, which can account for 50% of the runtimes' on-CPU time
let rt = Builder::new_current_thread().enable_all().build().unwrap();
let handler = rt.handle().clone();
let (tx, rx) = channel::<()>();
std::thread::Builder::new()
.name("Offload thread".to_string())
.spawn(move || {
debug!("Offload thread started");
// the thread that calls block_on() will drive the runtime
// rx will return when tx is dropped so this runtime and thread will exit
rt.block_on(rx)
})
.unwrap();
pools.push((handler, tx));
}
pools.into_boxed_slice()
}
pub fn get_runtime(&self, hash: u64) -> &Handle {
let mut rng = rand::thread_rng();
        // choose a shard based on the hash and a random thread within that shard
// e.g. say thread_per_shard=2, shard 1 thread 1 is 1 * 2 + 1 = 3
// [[th0, th1], [th2, th3], ...]
let shard = hash as usize % self.shards;
let thread_in_shard = rng.gen_range(0..self.thread_per_shard);
let pools = self.pools.get_or_init(|| self.init_pools());
&pools[shard * self.thread_per_shard + thread_in_shard].0
}
}

// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use log::debug;
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use std::sync::{Arc, Once};
use super::ConnectorOptions;
use crate::protocols::ssl::client::handshake;
use crate::protocols::ssl::SslStream;
use crate::protocols::IO;
use crate::tls::ext::{
add_host, clear_error_stack, ssl_add_chain_cert, ssl_set_groups_list,
ssl_set_renegotiate_mode_freely, ssl_set_verify_cert_store, ssl_use_certificate,
ssl_use_private_key, ssl_use_second_key_share,
};
#[cfg(feature = "boringssl")]
use crate::tls::ssl::SslCurve;
use crate::tls::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode, SslVersion};
use crate::tls::x509::store::X509StoreBuilder;
use crate::upstreams::peer::{Peer, ALPN};
const CIPHER_LIST: &str = "AES-128-GCM-SHA256\
:AES-256-GCM-SHA384\
:CHACHA20-POLY1305-SHA256\
:ECDHE-ECDSA-AES128-GCM-SHA256\
:ECDHE-ECDSA-AES256-GCM-SHA384\
:ECDHE-RSA-AES128-GCM-SHA256\
:ECDHE-RSA-AES256-GCM-SHA384\
:ECDHE-RSA-AES128-SHA\
:ECDHE-RSA-AES256-SHA384\
:AES128-GCM-SHA256\
:AES256-GCM-SHA384\
:AES128-SHA\
:AES256-SHA\
:DES-CBC3-SHA";
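The trailing backslashes in the string literal above splice the lines into a single colon-separated cipher string (a `\` before a newline skips the newline and any leading whitespace that follows). A shortened, hypothetical list to show the effect:

```rust
// A `\` line continuation inside a Rust string literal skips the newline
// and the next line's leading whitespace, yielding one contiguous string.
const LIST: &str = "AES-128-GCM-SHA256\
    :AES-256-GCM-SHA384\
    :CHACHA20-POLY1305-SHA256";

fn main() {
    assert_eq!(
        LIST,
        "AES-128-GCM-SHA256:AES-256-GCM-SHA384:CHACHA20-POLY1305-SHA256"
    );
}
```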
/**
* Enabled signature algorithms for signing/verification (ECDSA).
* As of 4/10/2023, the only addition to boringssl's defaults is ECDSA_SECP521R1_SHA512.
*/
const SIGALG_LIST: &str = "ECDSA_SECP256R1_SHA256\
:RSA_PSS_RSAE_SHA256\
:RSA_PKCS1_SHA256\
:ECDSA_SECP384R1_SHA384\
:RSA_PSS_RSAE_SHA384\
:RSA_PKCS1_SHA384\
:RSA_PSS_RSAE_SHA512\
:RSA_PKCS1_SHA512\
:RSA_PKCS1_SHA1\
:ECDSA_SECP521R1_SHA512";
/**
* Enabled curves for ECDHE (signature key exchange).
* As of 4/10/2023, the only addition to boringssl's defaults is SECP521R1.
*
* N.B. The ordering of these curves is important. The boringssl library will select the first one
* as a guess when negotiating a handshake with a server using TLSv1.3. We should opt for curves
* that are both computationally cheaper and more supported.
*/
#[cfg(feature = "boringssl")]
const BORINGSSL_CURVE_LIST: &[SslCurve] = &[
SslCurve::X25519,
SslCurve::SECP256R1,
SslCurve::SECP384R1,
SslCurve::SECP521R1,
];
static INIT_CA_ENV: Once = Once::new();
fn init_ssl_cert_env_vars() {
// this sets env vars to pick up the root certs
// it is universal across openssl and boringssl
INIT_CA_ENV.call_once(openssl_probe::init_ssl_cert_env_vars);
}
#[derive(Clone)]
pub struct Connector {
pub(crate) ctx: Arc<SslConnector>, // Arc to support clone
}
impl Connector {
pub fn new(options: Option<ConnectorOptions>) -> Self {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
// TODO: make these conf
// Set supported ciphers.
builder.set_cipher_list(CIPHER_LIST).unwrap();
// Set supported signature algorithms and ECDH (key exchange) curves.
builder
.set_sigalgs_list(&SIGALG_LIST.to_lowercase())
.unwrap();
#[cfg(feature = "boringssl")]
builder.set_curves(BORINGSSL_CURVE_LIST).unwrap();
builder
.set_max_proto_version(Some(SslVersion::TLS1_3))
.unwrap();
builder
.set_min_proto_version(Some(SslVersion::TLS1))
.unwrap();
if let Some(conf) = options.as_ref() {
if let Some(ca_file_path) = conf.ca_file.as_ref() {
builder.set_ca_file(ca_file_path).unwrap();
} else {
init_ssl_cert_env_vars();
// load from the default system-wide trust location (the method name is misleading)
builder.set_default_verify_paths().unwrap();
}
if let Some((cert, key)) = conf.cert_key_file.as_ref() {
builder
.set_certificate_file(cert, SslFiletype::PEM)
.unwrap();
builder.set_private_key_file(key, SslFiletype::PEM).unwrap();
}
} else {
init_ssl_cert_env_vars();
builder.set_default_verify_paths().unwrap();
}
Connector {
ctx: Arc::new(builder.build()),
}
}
}
/*
OpenSSL considers underscores in hostnames non-compliant.
We replace the underscore in the leftmost label as we must support these
hostnames for wildcard matches and we have not patched OpenSSL.
https://github.com/openssl/openssl/issues/12566
> The labels must follow the rules for ARPANET host names. They must
> start with a letter, end with a letter or digit, and have as interior
> characters only letters, digits, and hyphen. There are also some
> restrictions on the length. Labels must be 63 characters or less.
- https://datatracker.ietf.org/doc/html/rfc1034#section-3.5
*/
fn replace_leftmost_underscore(sni: &str) -> Option<String> {
// wildcard is only leftmost label
let mut s = sni.splitn(2, '.');
if let (Some(leftmost), Some(rest)) = (s.next(), s.next()) {
// if not a subdomain or leftmost does not contain underscore return
if !rest.contains('.') || !leftmost.contains('_') {
return None;
}
// we have a subdomain, replace underscores
let leftmost = leftmost.replace('_', "-");
return Some(format!("{leftmost}.{rest}"));
}
None
}
pub(crate) async fn connect<T, P>(
stream: T,
peer: &P,
alpn_override: Option<ALPN>,
tls_ctx: &SslConnector,
) -> Result<SslStream<T>>
where
T: IO,
P: Peer + Send + Sync,
{
let mut ssl_conf = tls_ctx.configure().unwrap();
ssl_set_renegotiate_mode_freely(&mut ssl_conf);
// Set up CA/verify cert store
// TODO: store X509Store in the peer directly
if let Some(ca_list) = peer.get_ca() {
let mut store_builder = X509StoreBuilder::new().unwrap();
for ca in &***ca_list {
store_builder.add_cert(ca.clone()).unwrap();
}
ssl_set_verify_cert_store(&mut ssl_conf, &store_builder.build())
.or_err(InternalError, "failed to load cert store")?;
}
// Set up client cert/key
if let Some(key_pair) = peer.get_client_cert_key() {
debug!("setting client cert and key");
ssl_use_certificate(&mut ssl_conf, key_pair.leaf())
.or_err(InternalError, "invalid client cert")?;
ssl_use_private_key(&mut ssl_conf, key_pair.key())
.or_err(InternalError, "invalid client key")?;
let intermediates = key_pair.intermediates();
if !intermediates.is_empty() {
debug!("adding intermediate certificates for mTLS chain");
for int in intermediates {
ssl_add_chain_cert(&mut ssl_conf, int)
.or_err(InternalError, "invalid intermediate client cert")?;
}
}
}
if let Some(curve) = peer.get_peer_options().and_then(|o| o.curves) {
ssl_set_groups_list(&mut ssl_conf, curve).or_err(InternalError, "invalid curves")?;
}
// second_keyshare is default true
if !peer.get_peer_options().map_or(true, |o| o.second_keyshare) {
ssl_use_second_key_share(&mut ssl_conf, false);
}
// disable verification if sni does not exist
// XXX: verifying with an empty string causes a null-string segfault
if peer.sni().is_empty() {
ssl_conf.set_use_server_name_indication(false);
/* NOTE: technically we can still verify who signs the cert but turn it off to be
consistent with nginx's behavior */
ssl_conf.set_verify(SslVerifyMode::NONE);
} else if peer.verify_cert() {
if peer.verify_hostname() {
let verify_param = ssl_conf.param_mut();
add_host(verify_param, peer.sni()).or_err(InternalError, "failed to add host")?;
// if sni had underscores in leftmost label replace and add
if let Some(sni_s) = replace_leftmost_underscore(peer.sni()) {
add_host(verify_param, sni_s.as_ref()).unwrap();
}
if let Some(alt_cn) = peer.alternative_cn() {
if !alt_cn.is_empty() {
add_host(verify_param, alt_cn).unwrap();
// if alt_cn had underscores in leftmost label replace and add
if let Some(alt_cn_s) = replace_leftmost_underscore(alt_cn) {
add_host(verify_param, alt_cn_s.as_ref()).unwrap();
}
}
}
}
ssl_conf.set_verify(SslVerifyMode::PEER);
} else {
ssl_conf.set_verify(SslVerifyMode::NONE);
}
/*
We always call set_verify_hostname(false) here because:
- verify case: otherwise ssl.connect calls X509_VERIFY_PARAM_set1_host,
which overrides the names added by add_host. Verification is
effectively on as long as the names are added.
- off case: the non-verify-hostname case should have it disabled
*/
ssl_conf.set_verify_hostname(false);
if let Some(alpn) = alpn_override.as_ref().or(peer.get_alpn()) {
ssl_conf.set_alpn_protos(alpn.to_wire_preference()).unwrap();
}
clear_error_stack();
let connect_future = handshake(ssl_conf, peer.sni(), stream);
match peer.connection_timeout() {
Some(t) => match pingora_timeout::timeout(t, connect_future).await {
Ok(res) => res,
Err(_) => Error::e_explain(
ConnectTimedout,
format!("connecting to server {}, timeout {:?}", peer, t),
),
},
None => connect_future.await,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_replace_leftmost_underscore() {
let none_cases = [
"",
"some",
"some.com",
"1.1.1.1:5050",
"dog.dot.com",
"dog.d_t.com",
"dog.dot.c_m",
"d_g.com",
"_",
"dog.c_m",
];
for case in none_cases {
assert!(replace_leftmost_underscore(case).is_none(), "{}", case);
}
assert_eq!(
Some("bb-b.some.com".to_string()),
replace_leftmost_underscore("bb_b.some.com")
);
assert_eq!(
Some("a-a-a.some.com".to_string()),
replace_leftmost_underscore("a_a_a.some.com")
);
assert_eq!(
Some("-.some.com".to_string()),
replace_leftmost_underscore("_.some.com")
);
}
}

pingora-core/src/lib.rs

@@ -0,0 +1,69 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![warn(clippy::all)]
#![allow(clippy::new_without_default)]
#![allow(clippy::type_complexity)]
#![allow(clippy::match_wild_err_arm)]
#![allow(clippy::missing_safety_doc)]
#![allow(clippy::upper_case_acronyms)]
// enable nightly feature async trait so that the docs are cleaner
#![cfg_attr(doc_async_trait, feature(async_fn_in_trait))]
//! # Pingora
//!
//! Pingora is a collection of service frameworks and network libraries battle-tested by the Internet.
//! It is used to build robust, scalable and secure network infrastructure and services at Internet scale.
//!
//! # Features
//! - HTTP/1.x and HTTP/2
//! - Modern TLS with OpenSSL or BoringSSL (FIPS compatible)
//! - Zero downtime upgrade
//!
//! # Usage
//! This crate provides low-level service and protocol implementations and abstractions.
//!
//! If you are looking to build a (reverse) proxy, see the `pingora-proxy` crate.
//!
//! # Optional features
//! `boringssl`: Switch the internal TLS library from OpenSSL to BoringSSL.
pub mod apps;
pub mod connectors;
pub mod listeners;
pub mod modules;
pub mod protocols;
pub mod server;
pub mod services;
pub mod upstreams;
pub mod utils;
pub use pingora_error::{ErrorType::*, *};
// If both openssl and boringssl are enabled, prefer boringssl.
// This is to make sure that boringssl can override the default openssl feature
// when this crate is used indirectly by other crates.
#[cfg(feature = "boringssl")]
pub use pingora_boringssl as tls;
#[cfg(all(not(feature = "boringssl"), feature = "openssl"))]
pub use pingora_openssl as tls;
pub mod prelude {
pub use crate::server::configuration::Opt;
pub use crate::server::Server;
pub use crate::services::background::background_service;
pub use crate::upstreams::peer::HttpPeer;
pub use pingora_error::{ErrorType::*, *};
}


@@ -0,0 +1,311 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use log::warn;
use pingora_error::{
ErrorType::{AcceptError, BindError},
OrErr, Result,
};
use std::fs::Permissions;
use std::io::ErrorKind;
use std::net::{SocketAddr, ToSocketAddrs};
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net::UnixListener as StdUnixListener;
use std::time::Duration;
use tokio::net::TcpSocket;
use crate::protocols::l4::listener::Listener;
pub use crate::protocols::l4::stream::Stream;
use crate::server::ListenFds;
const TCP_LISTENER_MAX_TRY: usize = 30;
const TCP_LISTENER_TRY_STEP: Duration = Duration::from_secs(1);
// TODO: configurable backlog
const LISTENER_BACKLOG: u32 = 65535;
/// Address for listening server, either TCP/UDS socket.
#[derive(Clone, Debug)]
pub enum ServerAddress {
Tcp(String, Option<TcpSocketOptions>),
Uds(String, Option<Permissions>),
}
impl AsRef<str> for ServerAddress {
fn as_ref(&self) -> &str {
match &self {
Self::Tcp(l, _) => l,
Self::Uds(l, _) => l,
}
}
}
/// TCP socket configuration options.
#[derive(Clone, Debug)]
pub struct TcpSocketOptions {
/// IPV6_V6ONLY flag (if true, limit socket to IPv6 communication only).
/// This is mostly useful when binding to `[::]`, which on most Unix distributions
/// will bind to both IPv4 and IPv6 addresses by default.
pub ipv6_only: bool,
// TODO: allow configuring reuseaddr, backlog, etc. from here?
}
mod uds {
use super::{OrErr, Result};
use crate::protocols::l4::listener::Listener;
use log::{debug, error};
use pingora_error::ErrorType::BindError;
use std::fs::{self, Permissions};
use std::io::ErrorKind;
use std::os::unix::fs::PermissionsExt;
use std::os::unix::net::UnixListener as StdUnixListener;
use tokio::net::UnixListener;
use super::LISTENER_BACKLOG;
pub(super) fn set_perms(path: &str, perms: Option<Permissions>) -> Result<()> {
// set read/write permissions for all users on the socket by default
let perms = perms.unwrap_or(Permissions::from_mode(0o666));
fs::set_permissions(path, perms).or_err_with(BindError, || {
format!("Fail to bind to {path}, could not set permissions")
})
}
pub(super) fn set_backlog(l: StdUnixListener, backlog: u32) -> Result<UnixListener> {
let socket: socket2::Socket = l.into();
// Note that we call listen() on an already listening socket;
// this is undefined in POSIX, but on Linux it updates the backlog size
socket
.listen(backlog as i32)
.or_err_with(BindError, || format!("listen() failed on {socket:?}"))?;
UnixListener::from_std(socket.into()).or_err(BindError, "Failed to convert to tokio socket")
}
pub(super) fn bind(addr: &str, perms: Option<Permissions>) -> Result<Listener> {
/*
We remove the filename/address in case there is a dangling reference.
"Binding to a socket with a filename creates a socket in the
filesystem that must be deleted by the caller when it is no
longer needed (using unlink(2))"
*/
match std::fs::remove_file(addr) {
Ok(()) => {
debug!("unlink {addr} done");
}
Err(e) => match e.kind() {
ErrorKind::NotFound => debug!("unlink {addr} not found: {e}"),
_ => error!("unlink {addr} failed: {e}"),
},
}
let listener_socket = UnixListener::bind(addr)
.or_err_with(BindError, || format!("Bind() failed on {addr}"))?;
set_perms(addr, perms)?;
let std_listener = listener_socket.into_std().unwrap();
Ok(set_backlog(std_listener, LISTENER_BACKLOG)?.into())
}
}
// currently, these options can only apply on sockets prior to calling bind()
fn apply_tcp_socket_options(sock: &TcpSocket, opt: Option<&TcpSocketOptions>) -> Result<()> {
let Some(opt) = opt else {
return Ok(());
};
let socket_ref = socket2::SockRef::from(sock);
socket_ref
.set_only_v6(opt.ipv6_only)
.or_err(BindError, "failed to set IPV6_V6ONLY")
}
fn from_raw_fd(address: &ServerAddress, fd: i32) -> Result<Listener> {
match address {
ServerAddress::Uds(addr, perm) => {
let std_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
// set permissions just in case
uds::set_perms(addr, perm.clone())?;
Ok(uds::set_backlog(std_listener, LISTENER_BACKLOG)?.into())
}
ServerAddress::Tcp(_, _) => {
let std_listener_socket = unsafe { std::net::TcpStream::from_raw_fd(fd) };
let listener_socket = TcpSocket::from_std_stream(std_listener_socket);
// Note that we call listen() on an already listening socket;
// this is undefined in POSIX, but on Linux it updates the backlog size
Ok(listener_socket
.listen(LISTENER_BACKLOG)
.or_err_with(BindError, || format!("Listen() failed on {address:?}"))?
.into())
}
}
}
async fn bind_tcp(addr: &str, opt: Option<TcpSocketOptions>) -> Result<Listener> {
let mut try_count = 0;
loop {
let sock_addr = addr
.to_socket_addrs() // NOTE: this could invoke a blocking network lookup
.or_err_with(BindError, || format!("Invalid listen address {addr}"))?
.next() // take the first one for now
.unwrap(); // assume there is always at least one
let listener_socket = match sock_addr {
SocketAddr::V4(_) => TcpSocket::new_v4(),
SocketAddr::V6(_) => TcpSocket::new_v6(),
}
.or_err_with(BindError, || format!("fail to create address {sock_addr}"))?;
// NOTE: this is to preserve the current TcpListener::bind() behavior.
// We have a few test relying on this behavior to allow multiple identical
// test servers to coexist.
listener_socket
.set_reuseaddr(true)
.or_err(BindError, "fail to set_reuseaddr(true)")?;
apply_tcp_socket_options(&listener_socket, opt.as_ref())?;
match listener_socket.bind(sock_addr) {
Ok(()) => {
break Ok(listener_socket
.listen(LISTENER_BACKLOG)
.or_err(BindError, "bind() failed")?
.into())
}
Err(e) => {
if e.kind() != ErrorKind::AddrInUse {
break Err(e).or_err_with(BindError, || format!("bind() failed on {addr}"));
}
try_count += 1;
if try_count >= TCP_LISTENER_MAX_TRY {
break Err(e).or_err_with(BindError, || {
format!("bind() failed, after retries, {addr} still in use")
});
}
warn!("{addr} is in use, will try again");
tokio::time::sleep(TCP_LISTENER_TRY_STEP).await;
}
}
}
}
async fn bind(addr: &ServerAddress) -> Result<Listener> {
match addr {
ServerAddress::Uds(l, perm) => uds::bind(l, perm.clone()),
ServerAddress::Tcp(l, opt) => bind_tcp(l, opt.clone()).await,
}
}
pub struct ListenerEndpoint {
listen_addr: ServerAddress,
listener: Option<Listener>,
}
impl ListenerEndpoint {
pub fn new(listen_addr: ServerAddress) -> Self {
ListenerEndpoint {
listen_addr,
listener: None,
}
}
pub fn as_str(&self) -> &str {
self.listen_addr.as_ref()
}
pub async fn listen(&mut self, fds: Option<ListenFds>) -> Result<()> {
if self.listener.is_some() {
return Ok(());
}
let listener = if let Some(fds_table) = fds {
let addr = self.listen_addr.as_ref();
// consider making this mutex a std::sync::Mutex or OnceCell
let mut table = fds_table.lock().await;
if let Some(fd) = table.get(addr.as_ref()) {
from_raw_fd(&self.listen_addr, *fd)?
} else {
// not found
let listener = bind(&self.listen_addr).await?;
table.add(addr.to_string(), listener.as_raw_fd());
listener
}
} else {
// not found, no fd table
bind(&self.listen_addr).await?
};
self.listener = Some(listener);
Ok(())
}
pub async fn accept(&mut self) -> Result<Stream> {
let Some(listener) = self.listener.as_mut() else {
// panic, otherwise this would loop forever
panic!("Need to call listen() first");
};
let mut stream = listener
.accept()
.await
.or_err(AcceptError, "Fail to accept()")?;
stream.set_nodelay()?;
Ok(stream)
}
}
#[cfg(test)]
mod test {
use super::*;
#[tokio::test]
async fn test_listen_tcp() {
let addr = "127.0.0.1:7100";
let mut listener = ListenerEndpoint::new(ServerAddress::Tcp(addr.into(), None));
listener.listen(None).await.unwrap();
tokio::spawn(async move {
// just try to accept once
listener.accept().await.unwrap();
});
tokio::net::TcpStream::connect(addr)
.await
.expect("can connect to TCP listener");
}
#[tokio::test]
async fn test_listen_tcp_ipv6_only() {
let sock_opt = Some(TcpSocketOptions { ipv6_only: true });
let mut listener = ListenerEndpoint::new(ServerAddress::Tcp("[::]:7101".into(), sock_opt));
listener.listen(None).await.unwrap();
tokio::spawn(async move {
// just try to accept twice
listener.accept().await.unwrap();
listener.accept().await.unwrap();
});
tokio::net::TcpStream::connect("127.0.0.1:7101")
.await
.expect_err("cannot connect to v4 addr");
tokio::net::TcpStream::connect("[::1]:7101")
.await
.expect("can connect to v6 addr");
}
#[tokio::test]
async fn test_listen_uds() {
let addr = "/tmp/test_listen_uds";
let mut listener = ListenerEndpoint::new(ServerAddress::Uds(addr.into(), None));
listener.listen(None).await.unwrap();
tokio::spawn(async move {
// just try to accept once
listener.accept().await.unwrap();
});
tokio::net::UnixStream::connect(addr)
.await
.expect("can connect to UDS listener");
}
}
@@ -0,0 +1,248 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! The listening endpoints (TCP and TLS) and their configurations.
mod l4;
mod tls;
use crate::protocols::Stream;
use crate::server::ListenFds;
use pingora_error::Result;
use std::{fs::Permissions, sync::Arc};
use l4::{ListenerEndpoint, Stream as L4Stream};
use tls::Acceptor;
pub use crate::protocols::ssl::server::TlsAccept;
pub use l4::{ServerAddress, TcpSocketOptions};
pub use tls::{TlsSettings, ALPN};
struct TransportStackBuilder {
l4: ServerAddress,
tls: Option<TlsSettings>,
}
impl TransportStackBuilder {
pub fn build(&mut self, upgrade_listeners: Option<ListenFds>) -> TransportStack {
TransportStack {
l4: ListenerEndpoint::new(self.l4.clone()),
tls: self.tls.take().map(|tls| Arc::new(tls.build())),
upgrade_listeners,
}
}
}
pub(crate) struct TransportStack {
l4: ListenerEndpoint,
tls: Option<Arc<Acceptor>>,
// listeners sent from the old process for graceful upgrade
upgrade_listeners: Option<ListenFds>,
}
impl TransportStack {
pub fn as_str(&self) -> &str {
self.l4.as_str()
}
pub async fn listen(&mut self) -> Result<()> {
self.l4.listen(self.upgrade_listeners.take()).await
}
pub async fn accept(&mut self) -> Result<UninitializedStream> {
let stream = self.l4.accept().await?;
Ok(UninitializedStream {
l4: stream,
tls: self.tls.clone(),
})
}
pub fn cleanup(&mut self) {
// placeholder
}
}
pub(crate) struct UninitializedStream {
l4: L4Stream,
tls: Option<Arc<Acceptor>>,
}
impl UninitializedStream {
pub async fn handshake(self) -> Result<Stream> {
if let Some(tls) = self.tls {
let tls_stream = tls.tls_handshake(self.l4).await?;
Ok(Box::new(tls_stream))
} else {
Ok(Box::new(self.l4))
}
}
}
/// The struct that holds one or more listening endpoints
pub struct Listeners {
stacks: Vec<TransportStackBuilder>,
}
impl Listeners {
/// Create a new [`Listeners`] with no listening endpoints.
pub fn new() -> Self {
Listeners { stacks: vec![] }
}
/// Create a new [`Listeners`] with a TCP server endpoint from the given string.
pub fn tcp(addr: &str) -> Self {
let mut listeners = Self::new();
listeners.add_tcp(addr);
listeners
}
/// Create a new [`Listeners`] with a Unix domain socket endpoint from the given string.
pub fn uds(addr: &str, perm: Option<Permissions>) -> Self {
let mut listeners = Self::new();
listeners.add_uds(addr, perm);
listeners
}
/// Create a new [`Listeners`] with a TLS (TCP) endpoint with the given address string,
/// and path to the certificate/private key pairs.
/// This endpoint will adopt the [Mozilla Intermediate](https://wiki.mozilla.org/Security/Server_Side_TLS#Intermediate_compatibility_.28recommended.29)
/// server side TLS settings.
pub fn tls(addr: &str, cert_path: &str, key_path: &str) -> Result<Self> {
let mut listeners = Self::new();
listeners.add_tls(addr, cert_path, key_path)?;
Ok(listeners)
}
/// Add a TCP endpoint to `self`.
pub fn add_tcp(&mut self, addr: &str) {
self.add_address(ServerAddress::Tcp(addr.into(), None));
}
/// Add a TCP endpoint to `self`, with the given [`TcpSocketOptions`].
pub fn add_tcp_with_settings(&mut self, addr: &str, sock_opt: TcpSocketOptions) {
self.add_address(ServerAddress::Tcp(addr.into(), Some(sock_opt)));
}
/// Add a Unix domain socket endpoint to `self`.
pub fn add_uds(&mut self, addr: &str, perm: Option<Permissions>) {
self.add_address(ServerAddress::Uds(addr.into(), perm));
}
/// Add a TLS endpoint to `self` with the [Mozilla Intermediate](https://wiki.mozilla.org/Security/Server_Side_TLS#Intermediate_compatibility_.28recommended.29)
/// server side TLS settings.
pub fn add_tls(&mut self, addr: &str, cert_path: &str, key_path: &str) -> Result<()> {
self.add_tls_with_settings(addr, None, TlsSettings::intermediate(cert_path, key_path)?);
Ok(())
}
/// Add a TLS endpoint to `self` with the given socket and server side TLS settings.
/// See [`TlsSettings`] and [`TcpSocketOptions`] for more details.
pub fn add_tls_with_settings(
&mut self,
addr: &str,
sock_opt: Option<TcpSocketOptions>,
settings: TlsSettings,
) {
self.add_endpoint(ServerAddress::Tcp(addr.into(), sock_opt), Some(settings));
}
/// Add the given [`ServerAddress`] to `self`.
pub fn add_address(&mut self, addr: ServerAddress) {
self.add_endpoint(addr, None);
}
/// Add the given [`ServerAddress`] to `self` with the given [`TlsSettings`] if provided
pub fn add_endpoint(&mut self, l4: ServerAddress, tls: Option<TlsSettings>) {
self.stacks.push(TransportStackBuilder { l4, tls })
}
pub(crate) fn build(&mut self, upgrade_listeners: Option<ListenFds>) -> Vec<TransportStack> {
self.stacks
.iter_mut()
.map(|b| b.build(upgrade_listeners.clone()))
.collect()
}
pub(crate) fn cleanup(&self) {
// placeholder
}
}
#[cfg(test)]
mod test {
use super::*;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::time::{sleep, Duration};
#[tokio::test]
async fn test_listen_tcp() {
let addr1 = "127.0.0.1:7101";
let addr2 = "127.0.0.1:7102";
let mut listeners = Listeners::tcp(addr1);
listeners.add_tcp(addr2);
let listeners = listeners.build(None);
assert_eq!(listeners.len(), 2);
for mut listener in listeners {
tokio::spawn(async move {
listener.listen().await.unwrap();
// just try to accept once
let stream = listener.accept().await.unwrap();
stream.handshake().await.unwrap();
});
}
// make sure the above starts before the lines below
sleep(Duration::from_millis(10)).await;
TcpStream::connect(addr1).await.unwrap();
TcpStream::connect(addr2).await.unwrap();
}
#[tokio::test]
async fn test_listen_tls() {
use tokio::io::AsyncReadExt;
let addr = "127.0.0.1:7103";
let cert_path = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR"));
let key_path = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR"));
let mut listeners = Listeners::tls(addr, &cert_path, &key_path).unwrap();
let mut listener = listeners.build(None).pop().unwrap();
tokio::spawn(async move {
listener.listen().await.unwrap();
// just try to accept once
let stream = listener.accept().await.unwrap();
let mut stream = stream.handshake().await.unwrap();
let mut buf = [0; 1024];
let _ = stream.read(&mut buf).await.unwrap();
stream
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\na")
.await
.unwrap();
});
// make sure the above starts before the lines below
sleep(Duration::from_millis(10)).await;
let client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.build()
.unwrap();
let res = client.get(format!("https://{addr}")).send().await.unwrap();
assert_eq!(res.status(), reqwest::StatusCode::OK);
}
}


@@ -0,0 +1,152 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use log::debug;
use pingora_error::{ErrorType, OrErr, Result};
use std::ops::{Deref, DerefMut};
use crate::protocols::ssl::{
server::{handshake, handshake_with_callback, TlsAcceptCallbacks},
SslStream,
};
use crate::protocols::IO;
use crate::tls::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod};
pub use crate::protocols::ssl::ALPN;
pub const TLS_CONF_ERR: ErrorType = ErrorType::Custom("TLSConfigError");
pub(crate) struct Acceptor {
ssl_acceptor: SslAcceptor,
callbacks: Option<TlsAcceptCallbacks>,
}
/// The TLS settings of a listening endpoint
pub struct TlsSettings {
accept_builder: SslAcceptorBuilder,
callbacks: Option<TlsAcceptCallbacks>,
}
impl Deref for TlsSettings {
type Target = SslAcceptorBuilder;
fn deref(&self) -> &Self::Target {
&self.accept_builder
}
}
impl DerefMut for TlsSettings {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.accept_builder
}
}
impl TlsSettings {
/// Create a new [`TlsSettings`] with the [Mozilla Intermediate](https://wiki.mozilla.org/Security/Server_Side_TLS#Intermediate_compatibility_.28recommended.29)
/// server side TLS settings. Users can adjust the TLS settings after this object is created.
/// Returns an error if the provided certificate or private key is invalid or not found.
pub fn intermediate(cert_path: &str, key_path: &str) -> Result<Self> {
let mut accept_builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).or_err(
TLS_CONF_ERR,
"fail to create mozilla_intermediate_v5 Acceptor",
)?;
accept_builder
.set_private_key_file(key_path, SslFiletype::PEM)
.or_err_with(TLS_CONF_ERR, || format!("fail to read key file {key_path}"))?;
accept_builder
.set_certificate_chain_file(cert_path)
.or_err_with(TLS_CONF_ERR, || format!("fail to read cert file {cert_path}"))?;
Ok(TlsSettings {
accept_builder,
callbacks: None,
})
}
/// Create a new [`TlsSettings`] similar to [TlsSettings::intermediate()]. A struct that implements [TlsAcceptCallbacks]
/// is needed to provide the certificate during the TLS handshake.
pub fn with_callbacks(callbacks: TlsAcceptCallbacks) -> Result<Self> {
let accept_builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).or_err(
TLS_CONF_ERR,
"fail to create mozilla_intermediate_v5 Acceptor",
)?;
Ok(TlsSettings {
accept_builder,
callbacks: Some(callbacks),
})
}
/// Enable HTTP/2 support for this endpoint, which is default off.
/// This effectively sets the ALPN to prefer HTTP/2 with HTTP/1.1 allowed
pub fn enable_h2(&mut self) {
self.set_alpn(ALPN::H2H1);
}
/// Set the ALPN preference of this endpoint. See [`ALPN`] for more details
pub fn set_alpn(&mut self, alpn: ALPN) {
match alpn {
ALPN::H2H1 => self
.accept_builder
.set_alpn_select_callback(alpn::prefer_h2),
ALPN::H1 => self.accept_builder.set_alpn_select_callback(alpn::h1_only),
ALPN::H2 => self.accept_builder.set_alpn_select_callback(alpn::h2_only),
}
}
pub(crate) fn build(self) -> Acceptor {
Acceptor {
ssl_acceptor: self.accept_builder.build(),
callbacks: self.callbacks,
}
}
}
impl Acceptor {
pub async fn tls_handshake<S: IO>(&self, stream: S) -> Result<SslStream<S>> {
debug!("new ssl session");
// TODO: be able to offload this handshake in a thread pool
if let Some(cb) = self.callbacks.as_ref() {
handshake_with_callback(&self.ssl_acceptor, stream, cb).await
} else {
handshake(&self.ssl_acceptor, stream).await
}
}
}
mod alpn {
use super::*;
use crate::tls::ssl::{select_next_proto, AlpnError, SslRef};
// A standard implementation provided by the SSL lib is used below
pub fn prefer_h2<'a>(_ssl: &mut SslRef, alpn_in: &'a [u8]) -> Result<&'a [u8], AlpnError> {
match select_next_proto(ALPN::H2H1.to_wire_preference(), alpn_in) {
Some(p) => Ok(p),
_ => Err(AlpnError::NOACK), // unknown ALPN, just ignore it. Most clients will fallback to h1
}
}
pub fn h1_only<'a>(_ssl: &mut SslRef, alpn_in: &'a [u8]) -> Result<&'a [u8], AlpnError> {
match select_next_proto(ALPN::H1.to_wire_preference(), alpn_in) {
Some(p) => Ok(p),
_ => Err(AlpnError::NOACK), // unknown ALPN, just ignore it. Most clients will fallback to h1
}
}
pub fn h2_only<'a>(_ssl: &mut SslRef, alpn_in: &'a [u8]) -> Result<&'a [u8], AlpnError> {
match select_next_proto(ALPN::H2.to_wire_preference(), alpn_in) {
Some(p) => Ok(p),
_ => Err(AlpnError::ALERT_FATAL), // cannot agree
}
}
}
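The `to_wire_preference()` byte strings consumed by `select_next_proto()` above use the standard TLS ALPN encoding (RFC 7301): each protocol name is prefixed by its one-byte length. A hypothetical encoder to illustrate the format (`to_wire` is not part of Pingora's API):

```rust
// Hypothetical encoder for an ALPN wire preference list: each protocol
// name is prefixed by its 1-byte length, concatenated in preference order.
fn to_wire(protos: &[&str]) -> Vec<u8> {
    let mut out = Vec::new();
    for p in protos {
        out.push(p.len() as u8); // 1-byte length prefix
        out.extend_from_slice(p.as_bytes());
    }
    out
}

fn main() {
    // preferring h2 over http/1.1, as ALPN::H2H1 does
    assert_eq!(to_wire(&["h2", "http/1.1"]), b"\x02h2\x08http/1.1".to_vec());
}
```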


@@ -0,0 +1,65 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HTTP compression filter
use super::*;
use crate::protocols::http::compression::ResponseCompressionCtx;
/// HTTP response compression module
pub struct ResponseCompression(ResponseCompressionCtx);
impl HttpModule for ResponseCompression {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn request_header_filter(&mut self, req: &mut RequestHeader) -> Result<()> {
self.0.request_filter(req);
Ok(())
}
fn response_filter(&mut self, t: &mut HttpTask) -> Result<()> {
self.0.response_filter(t);
Ok(())
}
}
/// The builder for HTTP response compression module
pub struct ResponseCompressionBuilder {
level: u32,
}
impl ResponseCompressionBuilder {
/// Return a [ModuleBuilder] for [ResponseCompression] with the given compression level
pub fn enable(level: u32) -> ModuleBuilder {
Box::new(ResponseCompressionBuilder { level })
}
}
impl HttpModuleBuilder for ResponseCompressionBuilder {
fn init(&self) -> Module {
Box::new(ResponseCompression(ResponseCompressionCtx::new(
self.level, false,
)))
}
fn order(&self) -> i16 {
// run the response filter later than most other filters
i16::MIN / 2
}
}


@ -0,0 +1,277 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Modules for HTTP traffic.
//!
//! [HttpModule]s define request and response filters to use while running an [HttpServer]
//! application.
//! See the [ResponseCompression] module for an example of how to implement a basic module.
pub mod compression;
use crate::protocols::http::HttpTask;
use bytes::Bytes;
use once_cell::sync::OnceCell;
use pingora_error::Result;
use pingora_http::RequestHeader;
use std::any::Any;
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::Arc;
/// The trait an HTTP traffic module needs to implement
// TODO: * async filters for, e.g., 3rd party auth server; * access the connection for, e.g., GeoIP
pub trait HttpModule {
fn request_header_filter(&mut self, _req: &mut RequestHeader) -> Result<()> {
Ok(())
}
fn request_body_filter(&mut self, body: Option<Bytes>) -> Result<Option<Bytes>> {
Ok(body)
}
fn response_filter(&mut self, _t: &mut HttpTask) -> Result<()> {
Ok(())
}
fn as_any(&self) -> &dyn Any;
fn as_any_mut(&mut self) -> &mut dyn Any;
}
type Module = Box<dyn HttpModule + 'static + Send + Sync>;
/// Trait to initialize the HTTP module ctx for each request
pub trait HttpModuleBuilder {
/// The order the module will run
///
/// The lower the value, the later it runs relative to other filters.
/// If the order of the filter is not important, leave it to the default 0.
fn order(&self) -> i16 {
0
}
/// Initialize and return the per request module context
fn init(&self) -> Module;
}
pub type ModuleBuilder = Box<dyn HttpModuleBuilder + 'static + Send + Sync>;
/// The object to hold multiple http modules
pub struct HttpModules {
modules: Vec<ModuleBuilder>,
module_index: OnceCell<Arc<HashMap<TypeId, usize>>>,
}
impl HttpModules {
/// Create a new [HttpModules]
pub fn new() -> Self {
HttpModules {
modules: vec![],
module_index: OnceCell::new(),
}
}
/// Add a new [ModuleBuilder] to [HttpModules]
///
/// Each type of [HttpModule] can only be added once.
/// # Panic
/// Panics if any [HttpModule] is added more than once.
pub fn add_module(&mut self, builder: ModuleBuilder) {
if self.module_index.get().is_some() {
// Because a shared module_index is already built, the index would go out
// of sync if we added more modules.
panic!("cannot add module after ctx is already built")
}
self.modules.push(builder);
// not the most efficient way but should be fine
// largest order first
self.modules.sort_by_key(|m| -m.order());
}
/// Build the contexts of all the modules added to this [HttpModules]
pub fn build_ctx(&self) -> HttpModuleCtx {
let module_ctx: Vec<_> = self.modules.iter().map(|b| b.init()).collect();
let module_index = self
.module_index
.get_or_init(|| {
let mut module_index = HashMap::with_capacity(self.modules.len());
for (i, c) in module_ctx.iter().enumerate() {
let exist = module_index.insert(c.as_any().type_id(), i);
if exist.is_some() {
panic!("duplicated filters found")
}
}
Arc::new(module_index)
})
.clone();
HttpModuleCtx {
module_ctx,
module_index,
}
}
}
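
The ordering rule in `add_module` is easy to misread: builders are sorted by the negated `order()`, i.e. descending, so a higher `order()` runs earlier and the default `0` sits in the middle. (This is possibly also why `ResponseCompressionBuilder` uses `i16::MIN / 2` rather than `i16::MIN`, since negating `i16::MIN` would overflow.) A tiny sketch of just the sort key:

```rust
// Builders are sorted by the negated order(), i.e. descending:
// higher order() values run first, as in add_module above.
fn main() {
    let mut orders: Vec<i16> = vec![0, -1, 1];
    orders.sort_by_key(|o| -*o);
    assert_eq!(orders, vec![1, 0, -1]); // order 1 runs first, order -1 last
}
```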
/// The Contexts of multiple modules
///
/// This is the object that will apply all the included modules to a certain HTTP request.
/// The modules are ordered according to their `order()`.
pub struct HttpModuleCtx {
// the modules in the order of execution
module_ctx: Vec<Module>,
// find the module in the vec with its type ID
module_index: Arc<HashMap<TypeId, usize>>,
}
impl HttpModuleCtx {
/// Create an empty placeholder [HttpModuleCtx].
///
/// [HttpModules] should be used to create nonempty [HttpModuleCtx].
pub fn empty() -> Self {
HttpModuleCtx {
module_ctx: vec![],
module_index: Arc::new(HashMap::new()),
}
}
/// Get a ref to [HttpModule] if any.
pub fn get<T: 'static>(&self) -> Option<&T> {
let idx = self.module_index.get(&TypeId::of::<T>())?;
let ctx = &self.module_ctx[*idx];
Some(
ctx.as_any()
.downcast_ref::<T>()
.expect("type should always match"),
)
}
/// Get a mut ref to [HttpModule] if any.
pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
let idx = self.module_index.get(&TypeId::of::<T>())?;
let ctx = &mut self.module_ctx[*idx];
Some(
ctx.as_any_mut()
.downcast_mut::<T>()
.expect("type should always match"),
)
}
/// Run the `request_header_filter` for all the modules according to their orders.
pub fn request_header_filter(&mut self, req: &mut RequestHeader) -> Result<()> {
for filter in self.module_ctx.iter_mut() {
filter.request_header_filter(req)?;
}
Ok(())
}
/// Run the `request_body_filter` for all the modules according to their orders.
pub fn request_body_filter(&mut self, mut body: Option<Bytes>) -> Result<Option<Bytes>> {
for filter in self.module_ctx.iter_mut() {
body = filter.request_body_filter(body)?;
}
Ok(body)
}
/// Run the `response_filter` for all the modules according to their orders.
pub fn response_filter(&mut self, t: &mut HttpTask) -> Result<()> {
for filter in self.module_ctx.iter_mut() {
filter.response_filter(t)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
struct MyModule;
impl HttpModule for MyModule {
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
fn request_header_filter(&mut self, req: &mut RequestHeader) -> Result<()> {
req.insert_header("my-filter", "1")
}
}
struct MyModuleBuilder;
impl HttpModuleBuilder for MyModuleBuilder {
fn order(&self) -> i16 {
1
}
fn init(&self) -> Module {
Box::new(MyModule)
}
}
struct MyOtherModule;
impl HttpModule for MyOtherModule {
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
fn request_header_filter(&mut self, req: &mut RequestHeader) -> Result<()> {
if req.headers.get("my-filter").is_some() {
// if this MyOtherModule runs after MyModule
req.insert_header("my-filter", "2")
} else {
// if this MyOtherModule runs before MyModule
req.insert_header("my-other-filter", "1")
}
}
}
struct MyOtherModuleBuilder;
impl HttpModuleBuilder for MyOtherModuleBuilder {
fn order(&self) -> i16 {
-1
}
fn init(&self) -> Module {
Box::new(MyOtherModule)
}
}
#[test]
fn test_module_get() {
let mut http_module = HttpModules::new();
http_module.add_module(Box::new(MyModuleBuilder));
http_module.add_module(Box::new(MyOtherModuleBuilder));
let mut ctx = http_module.build_ctx();
assert!(ctx.get::<MyModule>().is_some());
assert!(ctx.get::<MyOtherModule>().is_some());
assert!(ctx.get::<usize>().is_none());
assert!(ctx.get_mut::<MyModule>().is_some());
assert!(ctx.get_mut::<MyOtherModule>().is_some());
assert!(ctx.get_mut::<usize>().is_none());
}
#[test]
fn test_module_filter() {
let mut http_module = HttpModules::new();
http_module.add_module(Box::new(MyOtherModuleBuilder));
http_module.add_module(Box::new(MyModuleBuilder));
let mut ctx = http_module.build_ctx();
let mut req = RequestHeader::build("Get", b"/", None).unwrap();
ctx.request_header_filter(&mut req).unwrap();
// MyModule runs before MyOtherModule
assert_eq!(req.headers.get("my-filter").unwrap(), "2");
assert!(req.headers.get("my-other-filter").is_none());
}
}


@ -0,0 +1,16 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Modules to extend the functionalities of pingora services.
pub mod http;


@ -0,0 +1,66 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Extra information about the connection
use std::sync::Arc;
use std::time::SystemTime;
use super::raw_connect::ProxyDigest;
use super::ssl::digest::SslDigest;
/// The information that can be extracted from a connection
#[derive(Clone, Debug)]
pub struct Digest {
/// Information regarding the TLS of this connection if any
pub ssl_digest: Option<Arc<SslDigest>>,
/// Timing information
pub timing_digest: Vec<Option<TimingDigest>>,
/// Information regarding the CONNECT proxy this connection uses, if any
pub proxy_digest: Option<Arc<ProxyDigest>>,
}
/// The interface to return protocol-related information
pub trait ProtoDigest {
fn get_digest(&self) -> Option<&Digest> {
None
}
}
/// The timing information of the connection
#[derive(Clone, Debug)]
pub struct TimingDigest {
/// When this connection was established
pub established_ts: SystemTime,
}
impl Default for TimingDigest {
fn default() -> Self {
TimingDigest {
established_ts: SystemTime::UNIX_EPOCH,
}
}
}
/// The interface to return timing information
pub trait GetTimingDigest {
/// Return the timing for each layer, from the lowest layer to the uppermost
fn get_timing_digest(&self) -> Vec<Option<TimingDigest>>;
}
/// The interface to set or return proxy information
pub trait GetProxyDigest {
fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>>;
fn set_proxy_digest(&mut self, _digest: ProxyDigest) {}
}


@ -0,0 +1,61 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bytes::{Bytes, BytesMut};
/// A buffer with a size limit. When the total amount of data written to the buffer is below the
/// limit, all the data will be held in the buffer. Otherwise, the buffer will report itself as truncated.
pub(crate) struct FixedBuffer {
buffer: BytesMut,
capacity: usize,
truncated: bool,
}
impl FixedBuffer {
pub fn new(capacity: usize) -> Self {
FixedBuffer {
buffer: BytesMut::new(),
capacity,
truncated: false,
}
}
// TODO: maybe store a Vec of Bytes for zero-copy
pub fn write_to_buffer(&mut self, data: &Bytes) {
if !self.truncated && (self.buffer.len() + data.len() <= self.capacity) {
self.buffer.extend_from_slice(data);
} else {
// TODO: clear data because the data held here is useless anyway?
self.truncated = true;
}
}
pub fn clear(&mut self) {
self.truncated = false;
self.buffer.clear();
}
pub fn is_empty(&self) -> bool {
self.buffer.len() == 0
}
pub fn is_truncated(&self) -> bool {
self.truncated
}
pub fn get_buffer(&self) -> Option<Bytes> {
// TODO: return None if truncated?
if !self.is_empty() {
Some(self.buffer.clone().freeze())
} else {
None
}
}
}
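
A sketch of `FixedBuffer`'s truncation semantics, restated as a self-contained std-only analog (`BytesMut` swapped for `Vec<u8>`): once a write would exceed capacity, the buffer flags itself as truncated and drops that and all subsequent data, while the data already held is kept.

```rust
// Std-only analog of FixedBuffer's write path, for illustration.
struct FixedBuffer {
    buffer: Vec<u8>,
    capacity: usize,
    truncated: bool,
}

impl FixedBuffer {
    fn new(capacity: usize) -> Self {
        FixedBuffer { buffer: Vec::new(), capacity, truncated: false }
    }
    fn write_to_buffer(&mut self, data: &[u8]) {
        if !self.truncated && self.buffer.len() + data.len() <= self.capacity {
            self.buffer.extend_from_slice(data);
        } else {
            // over capacity: keep what is held, drop new data, flag truncation
            self.truncated = true;
        }
    }
}

fn main() {
    let mut buf = FixedBuffer::new(8);
    buf.write_to_buffer(b"hello");
    assert!(!buf.truncated);
    buf.write_to_buffer(b"world"); // 5 + 5 > 8: triggers truncation
    assert!(buf.truncated);
    assert_eq!(&buf.buffer[..], &b"hello"[..]); // earlier data is kept
}
```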


@ -0,0 +1,161 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bytes::Bytes;
use pingora_error::Result;
use pingora_http::{RequestHeader, ResponseHeader};
use std::time::Duration;
use super::v1::client::HttpSession as Http1Session;
use super::v2::client::Http2Session;
use crate::protocols::Digest;
/// A type for an HTTP client session. It can be either an HTTP/1.x connection or an HTTP/2 stream.
pub enum HttpSession {
H1(Http1Session),
H2(Http2Session),
}
impl HttpSession {
pub fn as_http1(&self) -> Option<&Http1Session> {
match self {
Self::H1(s) => Some(s),
Self::H2(_) => None,
}
}
pub fn as_http2(&self) -> Option<&Http2Session> {
match self {
Self::H1(_) => None,
Self::H2(s) => Some(s),
}
}
/// Write the request header to the server.
///
/// After the request header is sent, the caller can either start reading the response or
/// sending the request body, if any.
pub async fn write_request_header(&mut self, req: Box<RequestHeader>) -> Result<()> {
match self {
HttpSession::H1(h1) => {
h1.write_request_header(req).await?;
Ok(())
}
HttpSession::H2(h2) => h2.write_request_header(req, false),
}
}
/// Write a chunk of the request body.
pub async fn write_request_body(&mut self, data: Bytes, end: bool) -> Result<()> {
match self {
HttpSession::H1(h1) => {
// TODO: maybe h1 should also have the concept of `end`
h1.write_body(&data).await?;
Ok(())
}
HttpSession::H2(h2) => h2.write_request_body(data, end),
}
}
/// Signal that the request body has ended
pub async fn finish_request_body(&mut self) -> Result<()> {
match self {
HttpSession::H1(h1) => {
h1.finish_body().await?;
Ok(())
}
HttpSession::H2(h2) => h2.finish_request_body(),
}
}
/// Set the read timeout for reading header and body.
///
/// The timeout is per read operation, not on the overall time reading the entire response
pub fn set_read_timeout(&mut self, timeout: Duration) {
match self {
HttpSession::H1(h1) => h1.read_timeout = Some(timeout),
HttpSession::H2(h2) => h2.read_timeout = Some(timeout),
}
}
/// Set the write timeout for writing header and body.
///
/// The timeout is per write operation, not on the overall time writing the entire request
pub fn set_write_timeout(&mut self, timeout: Duration) {
match self {
HttpSession::H1(h1) => h1.write_timeout = Some(timeout),
HttpSession::H2(_) => { /* no write timeout because the actual write happens async*/ }
}
}
/// Read the response header from the server.
///
/// For HTTP/1.x, this function can be called multiple times when the headers received so far
/// are only informational headers.
pub async fn read_response_header(&mut self) -> Result<()> {
match self {
HttpSession::H1(h1) => {
h1.read_response().await?;
Ok(())
}
HttpSession::H2(h2) => h2.read_response_header().await,
}
}
/// Read response body
///
/// `None` when no more body to read.
pub async fn read_response_body(&mut self) -> Result<Option<Bytes>> {
match self {
HttpSession::H1(h1) => h1.read_body_bytes().await,
HttpSession::H2(h2) => h2.read_response_body().await,
}
}
/// No (more) body to read
pub fn response_done(&mut self) -> bool {
match self {
HttpSession::H1(h1) => h1.is_body_done(),
HttpSession::H2(h2) => h2.response_finished(),
}
}
/// Give up the HTTP session abruptly.
///
/// For H1 this will close the underlying connection. For H2 this will send an RST_STREAM
/// frame to end the stream if it has not already ended.
pub async fn shutdown(&mut self) {
match self {
Self::H1(s) => s.shutdown().await,
Self::H2(s) => s.shutdown(),
}
}
/// Get the response header of the server
///
/// `None` if the response header is not read yet.
pub fn response_header(&self) -> Option<&ResponseHeader> {
match self {
Self::H1(s) => s.resp_header(),
Self::H2(s) => s.response_header(),
}
}
/// Return the [Digest] of the connection
///
/// For a reused connection, the timing in the digest will reflect its initial handshakes.
/// The caller should check whether the connection is reused to avoid misusing the timing fields.
pub fn digest(&self) -> Option<&Digest> {
match self {
Self::H1(s) => Some(s.digest()),
Self::H2(s) => s.digest(),
}
}
}


@ -0,0 +1,161 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::Encode;
use super::COMPRESSION_ERROR;
use brotli::{CompressorWriter, DecompressorWriter};
use bytes::Bytes;
use pingora_error::{OrErr, Result};
use std::io::Write;
use std::time::{Duration, Instant};
pub struct Decompressor {
decompress: DecompressorWriter<Vec<u8>>,
total_in: usize,
total_out: usize,
duration: Duration,
}
impl Decompressor {
pub fn new() -> Self {
Decompressor {
// default buf is 4096 if 0 is used, TODO: figure out the significance of this value
decompress: DecompressorWriter::new(vec![], 0),
total_in: 0,
total_out: 0,
duration: Duration::new(0, 0),
}
}
}
impl Encode for Decompressor {
fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes> {
// for small inputs, reserve at most 16k (4k cap x 4 estimated ratio)
const MAX_INIT_COMPRESSED_SIZE_CAP: usize = 4 * 1024;
// Brotli compress ratio can be 3.5 to 4.5
const ESTIMATED_COMPRESSION_RATIO: usize = 4;
let start = Instant::now();
self.total_in += input.len();
// cap the buffer size amplification; always allocating 4x the memory of
// the input buffer would be a DoS risk
let reserve_size = if input.len() < MAX_INIT_COMPRESSED_SIZE_CAP {
input.len() * ESTIMATED_COMPRESSION_RATIO
} else {
input.len()
};
self.decompress.get_mut().reserve(reserve_size);
self.decompress
.write_all(input)
.or_err(COMPRESSION_ERROR, "while decompressing Brotli")?;
// write to vec will never fail. The only possible error is that the input data
// is invalid (not brotli compressed)
if end {
self.decompress
.flush()
.or_err(COMPRESSION_ERROR, "while decompressing Brotli")?;
}
self.total_out += self.decompress.get_ref().len();
self.duration += start.elapsed();
Ok(std::mem::take(self.decompress.get_mut()).into()) // into() Bytes will drop excess capacity
}
fn stat(&self) -> (&'static str, usize, usize, Duration) {
("de-brotli", self.total_in, self.total_out, self.duration)
}
}
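
The reserve-size rule above is worth spelling out: inputs below the 4 KiB cap are amplified by the estimated ~4x Brotli ratio, while larger inputs reserve only their own length, bounding memory amplification. A sketch of just that computation:

```rust
// The reserve-size computation from Decompressor::encode above.
fn reserve_size(input_len: usize) -> usize {
    const MAX_INIT_COMPRESSED_SIZE_CAP: usize = 4 * 1024;
    const ESTIMATED_COMPRESSION_RATIO: usize = 4;
    if input_len < MAX_INIT_COMPRESSED_SIZE_CAP {
        input_len * ESTIMATED_COMPRESSION_RATIO // small input: amplify by ~4x
    } else {
        input_len // large input: no amplification, bounding memory use
    }
}

fn main() {
    assert_eq!(reserve_size(1000), 4000); // amplified
    assert_eq!(reserve_size(8192), 8192); // capped: no amplification
}
```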
pub struct Compressor {
compress: CompressorWriter<Vec<u8>>,
total_in: usize,
total_out: usize,
duration: Duration,
}
impl Compressor {
pub fn new(level: u32) -> Self {
Compressor {
// buf_size:4096 , lgwin:19 TODO: fine tune these
compress: CompressorWriter::new(vec![], 4096, level, 19),
total_in: 0,
total_out: 0,
duration: Duration::new(0, 0),
}
}
}
impl Encode for Compressor {
fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes> {
// reserve at most 16k
const MAX_INIT_COMPRESSED_BUF_SIZE: usize = 16 * 1024;
let start = Instant::now();
self.total_in += input.len();
// reserve at most input size, cap at 16k, compressed output should be smaller
self.compress
.get_mut()
.reserve(std::cmp::min(MAX_INIT_COMPRESSED_BUF_SIZE, input.len()));
self.compress
.write_all(input)
.or_err(COMPRESSION_ERROR, "while compressing Brotli")?;
// write to vec will never fail.
if end {
self.compress
.flush()
.or_err(COMPRESSION_ERROR, "while compressing Brotli")?;
}
self.total_out += self.compress.get_ref().len();
self.duration += start.elapsed();
Ok(std::mem::take(self.compress.get_mut()).into()) // into() Bytes will drop excess capacity
}
fn stat(&self) -> (&'static str, usize, usize, Duration) {
("brotli", self.total_in, self.total_out, self.duration)
}
}
#[cfg(test)]
mod tests_stream {
use super::*;
#[test]
fn decompress_brotli_data() {
let mut decompressor = Decompressor::new();
let decompressed = decompressor
.encode(
&[
0x1f, 0x0f, 0x00, 0xf8, 0x45, 0x07, 0x87, 0x3e, 0x10, 0xfb, 0x55, 0x92, 0xec,
0x12, 0x09, 0xcc, 0x38, 0xdd, 0x51, 0x1e,
],
true,
)
.unwrap();
assert_eq!(&decompressed[..], &b"adcdefgabcdefgh\n"[..]);
}
#[test]
fn compress_brotli_data() {
let mut compressor = Compressor::new(11);
let compressed = compressor.encode(&b"adcdefgabcdefgh\n"[..], true).unwrap();
assert_eq!(
&compressed[..],
&[
0x85, 0x07, 0x00, 0xf8, 0x45, 0x07, 0x87, 0x3e, 0x10, 0xfb, 0x55, 0x92, 0xec, 0x12,
0x09, 0xcc, 0x38, 0xdd, 0x51, 0x1e,
],
);
}
}


@ -0,0 +1,103 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::Encode;
use bytes::Bytes;
use flate2::write::GzEncoder;
use pingora_error::Result;
use std::io::Write;
use std::time::{Duration, Instant};
// TODO: unzip
pub struct Compressor {
// TODO: enum for other compression algorithms
compress: GzEncoder<Vec<u8>>,
total_in: usize,
total_out: usize,
duration: Duration,
}
impl Compressor {
pub fn new(level: u32) -> Compressor {
Compressor {
compress: GzEncoder::new(vec![], flate2::Compression::new(level)),
total_in: 0,
total_out: 0,
duration: Duration::new(0, 0),
}
}
}
impl Encode for Compressor {
// infallible because compression can take any data
fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes> {
// reserve at most 16k
const MAX_INIT_COMPRESSED_BUF_SIZE: usize = 16 * 1024;
let start = Instant::now();
self.total_in += input.len();
self.compress
.get_mut()
.reserve(std::cmp::min(MAX_INIT_COMPRESSED_BUF_SIZE, input.len()));
self.write_all(input).unwrap(); // write to vec, should never fail
if end {
self.try_finish().unwrap(); // write to vec, should never fail
}
self.total_out += self.compress.get_ref().len();
self.duration += start.elapsed();
Ok(std::mem::take(self.compress.get_mut()).into()) // into() Bytes will drop excess capacity
}
fn stat(&self) -> (&'static str, usize, usize, Duration) {
("gzip", self.total_in, self.total_out, self.duration)
}
}
use std::ops::{Deref, DerefMut};
impl Deref for Compressor {
type Target = GzEncoder<Vec<u8>>;
fn deref(&self) -> &Self::Target {
&self.compress
}
}
impl DerefMut for Compressor {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.compress
}
}
#[cfg(test)]
mod tests_stream {
use super::*;
#[test]
fn gzip_data() {
let mut compressor = Compressor::new(6);
let compressed = compressor.encode(b"abcdefg", true).unwrap();
// gzip magic headers
assert_eq!(&compressed[..3], &[0x1f, 0x8b, 0x08]);
// check the crc32 footer
assert_eq!(
&compressed[compressed.len() - 9..],
&[0, 166, 106, 42, 49, 7, 0, 0, 0]
);
assert_eq!(compressor.total_in, 7);
assert_eq!(compressor.total_out, compressed.len());
assert!(compressor.get_ref().is_empty());
}
}
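
The footer assertion in the test above can be decoded by hand: per RFC 1952, a gzip stream ends with an 8-byte trailer holding the CRC32 of the uncompressed data and then its length (ISIZE), both little-endian. Using the byte values from the test (the CRC value itself is taken from that expected output):

```rust
// Decode the 8-byte gzip trailer checked in the test above:
// 4 bytes CRC32, then 4 bytes ISIZE, both little-endian (RFC 1952).
fn main() {
    let crc32 = u32::from_le_bytes([166, 106, 42, 49]);
    let isize_ = u32::from_le_bytes([7, 0, 0, 0]);
    assert_eq!(isize_, 7); // b"abcdefg" is 7 bytes
    assert_eq!(crc32, 0x312a_6aa6); // CRC32 of the uncompressed data, per the test
}
```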


@ -0,0 +1,612 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HTTP response (de)compression libraries
//!
//! Brotli and Gzip are partially supported.
use super::HttpTask;
use bytes::Bytes;
use log::warn;
use pingora_error::{ErrorType, Result};
use pingora_http::{RequestHeader, ResponseHeader};
use std::time::Duration;
mod brotli;
mod gzip;
mod zstd;
/// The type of error to return when (de)compression fails
pub const COMPRESSION_ERROR: ErrorType = ErrorType::new("CompressionError");
/// The trait for both compression and decompression, because the interface and syntax are the
/// same: encode some bytes into other bytes
pub trait Encode {
/// Encode the input bytes. The `end` flag signals the end of the entire input. The `end` flag
/// helps the encoder to flush out the remaining buffered encoded data because certain compression
/// algorithms prefer to collect large enough data to compress all together.
fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes>;
/// Return the Encoder's name, the total input bytes, the total output bytes and the total
/// duration spent on encoding the data.
fn stat(&self) -> (&'static str, usize, usize, Duration);
}
/// The response compression object. It currently supports gzip compression and brotli decompression.
///
/// To use it, the caller should create a [`ResponseCompressionCtx`] per HTTP session.
/// The caller should call the corresponding filters for the request header, response header and
/// response body. If the algorithms are supported, the output response body will be encoded.
/// The response header will be adjusted accordingly as well. If the algorithm is not supported
/// or no encoding needed, the response is untouched.
///
/// If configured, and if the request's `accept-encoding` header contains a supported algorithm
/// that the incoming response isn't already encoded with, the filter will compress the response.
/// If configured, and if the incoming response's `content-encoding` is a supported algorithm that
/// isn't in the request's `accept-encoding`, the ctx will decompress the response.
///
/// # Currently supported algorithms and actions
/// - Brotli decompression: if the response is br compressed, this ctx can decompress it
/// - Gzip compression: if the response is uncompressed, this ctx can compress it with gzip
pub struct ResponseCompressionCtx(CtxInner);
enum CtxInner {
HeaderPhase {
compression_level: u32,
decompress_enable: bool,
// Store the preferred list to compare with content-encoding
accept_encoding: Vec<Algorithm>,
},
BodyPhase(Option<Box<dyn Encode + Send + Sync>>),
}
impl ResponseCompressionCtx {
/// Create a new [`ResponseCompressionCtx`] with the expected compression level. `0` will disable
/// the compression.
/// The `decompress_enable` flag will tell the ctx to decompress if needed.
pub fn new(compression_level: u32, decompress_enable: bool) -> Self {
Self(CtxInner::HeaderPhase {
compression_level,
decompress_enable,
accept_encoding: Vec::new(),
})
}
/// Whether the encoder is enabled.
/// The enablement may change as this ctx filters the request and response.
pub fn is_enabled(&self) -> bool {
match &self.0 {
CtxInner::HeaderPhase {
compression_level,
decompress_enable,
accept_encoding: _,
} => *compression_level != 0 || *decompress_enable,
CtxInner::BodyPhase(c) => c.is_some(),
}
}
/// Return the stat of this ctx:
/// algorithm name, in bytes, out bytes, time taken for the compression
pub fn get_info(&self) -> Option<(&'static str, usize, usize, Duration)> {
match &self.0 {
CtxInner::HeaderPhase {
compression_level: _,
decompress_enable: _,
accept_encoding: _,
} => None,
CtxInner::BodyPhase(c) => c.as_ref().map(|c| c.stat()),
}
}
/// Adjust the compression level.
/// # Panic
/// This function will panic if it has already started encoding the response body.
pub fn adjust_level(&mut self, new_level: u32) {
match &mut self.0 {
CtxInner::HeaderPhase {
compression_level,
decompress_enable: _,
accept_encoding: _,
} => {
*compression_level = new_level;
}
CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
}
}
/// Adjust the decompression flag.
/// # Panic
/// This function will panic if it has already started encoding the response body.
pub fn adjust_decompression(&mut self, enabled: bool) {
match &mut self.0 {
CtxInner::HeaderPhase {
compression_level: _,
decompress_enable,
accept_encoding: _,
} => {
*decompress_enable = enabled;
}
CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
}
}
/// Feed the request header into this ctx.
pub fn request_filter(&mut self, req: &RequestHeader) {
if !self.is_enabled() {
return;
}
match &mut self.0 {
CtxInner::HeaderPhase {
compression_level: _,
decompress_enable: _,
accept_encoding,
} => parse_accept_encoding(
req.headers.get(http::header::ACCEPT_ENCODING),
accept_encoding,
),
CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
}
}
fn response_header_filter(&mut self, resp: &mut ResponseHeader, end: bool) {
match &self.0 {
CtxInner::HeaderPhase {
compression_level,
decompress_enable,
accept_encoding,
} => {
if resp.status.is_informational() {
if resp.status == http::status::StatusCode::SWITCHING_PROTOCOLS {
// no transformation for websocket (TODO: cite RFC)
self.0 = CtxInner::BodyPhase(None);
}
// else, wait for the final response header for decision
return;
}
// do nothing if no body
if end {
self.0 = CtxInner::BodyPhase(None);
return;
}
let action = decide_action(resp, accept_encoding);
let encoder = match action {
Action::Noop => None,
Action::Compress(algorithm) => algorithm.compressor(*compression_level),
Action::Decompress(algorithm) => algorithm.decompressor(*decompress_enable),
};
if encoder.is_some() {
adjust_response_header(resp, &action);
}
self.0 = CtxInner::BodyPhase(encoder);
}
CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
}
}
fn response_body_filter(&mut self, data: Option<&Bytes>, end: bool) -> Option<Bytes> {
match &mut self.0 {
CtxInner::HeaderPhase {
compression_level: _,
decompress_enable: _,
accept_encoding: _,
} => panic!("Wrong phase: HeaderPhase"),
CtxInner::BodyPhase(compressor) => {
let result = compressor
.as_mut()
.map(|c| {
// Feed even empty slice to compressor because it might yield data
// when `end` is true
let data = if let Some(b) = data { b.as_ref() } else { &[] };
c.encode(data, end)
})
.transpose();
result.unwrap_or_else(|e| {
warn!("Failed to compress, compression disabled, {}", e);
// no point to transcode further data because bad data is already seen
self.0 = CtxInner::BodyPhase(None);
None
})
}
}
}
/// Feed the response into this ctx.
/// This filter will mutate the response accordingly if encoding is needed.
pub fn response_filter(&mut self, t: &mut HttpTask) {
if !self.is_enabled() {
return;
}
match t {
HttpTask::Header(resp, end) => self.response_header_filter(resp, *end),
HttpTask::Body(data, end) => {
let compressed = self.response_body_filter(data.as_ref(), *end);
if compressed.is_some() {
*t = HttpTask::Body(compressed, *end);
}
}
HttpTask::Done => {
// try to finish/flush compression
let compressed = self.response_body_filter(None, true);
if compressed.is_some() {
// compressor has more data to flush
*t = HttpTask::Body(compressed, true);
}
}
_ => { /* Trailer, Failed: do nothing? */ }
}
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Algorithm {
Any, // the "*"
Gzip,
Brotli,
Zstd,
// TODO: Identify,
// TODO: Deflate
Other, // anything unknown
}
impl Algorithm {
pub fn as_str(&self) -> &'static str {
match self {
Algorithm::Gzip => "gzip",
Algorithm::Brotli => "br",
Algorithm::Zstd => "zstd",
Algorithm::Any => "*",
Algorithm::Other => "other",
}
}
pub fn compressor(&self, level: u32) -> Option<Box<dyn Encode + Send + Sync>> {
if level == 0 {
None
} else {
match self {
Self::Gzip => Some(Box::new(gzip::Compressor::new(level))),
Self::Brotli => Some(Box::new(brotli::Compressor::new(level))),
Self::Zstd => Some(Box::new(zstd::Compressor::new(level))),
_ => None, // not implemented
}
}
}
pub fn decompressor(&self, enabled: bool) -> Option<Box<dyn Encode + Send + Sync>> {
if !enabled {
None
} else {
match self {
Self::Brotli => Some(Box::new(brotli::Decompressor::new())),
_ => None, // not implemented
}
}
}
}
impl From<&str> for Algorithm {
fn from(s: &str) -> Self {
use unicase::UniCase;
let coding = UniCase::new(s);
if coding == UniCase::ascii("gzip") {
Algorithm::Gzip
} else if coding == UniCase::ascii("br") {
Algorithm::Brotli
} else if coding == UniCase::ascii("zstd") {
Algorithm::Zstd
} else if s.is_empty() {
Algorithm::Any
} else {
Algorithm::Other
}
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Action {
Noop, // do nothing, e.g. when the input is already gzip
Compress(Algorithm),
Decompress(Algorithm),
}
// parse the Accept-Encoding header and add the accepted algorithms to the list
fn parse_accept_encoding(accept_encoding: Option<&http::HeaderValue>, list: &mut Vec<Algorithm>) {
// https://www.rfc-editor.org/rfc/rfc9110#name-accept-encoding
if let Some(ac) = accept_encoding {
// fast path
if ac.as_bytes() == b"gzip" {
list.push(Algorithm::Gzip);
return;
}
// properly parse AC header
match sfv::Parser::parse_list(ac.as_bytes()) {
Ok(parsed) => {
for item in parsed {
if let sfv::ListEntry::Item(i) = item {
if let Some(s) = i.bare_item.as_token() {
// TODO: support q value
let algorithm = Algorithm::from(s);
// ignore algorithms that we don't understand
if algorithm != Algorithm::Other {
list.push(Algorithm::from(s));
}
}
}
}
}
Err(e) => {
warn!("Failed to parse accept-encoding {ac:?}, {e}")
}
}
} else {
// "If no Accept-Encoding header, any content coding is acceptable"
// keep the list empty
}
}
#[test]
fn test_accept_encoding_req_header() {
let mut header = RequestHeader::build("GET", b"/", None).unwrap();
let mut ac_list = Vec::new();
parse_accept_encoding(
header.headers.get(http::header::ACCEPT_ENCODING),
&mut ac_list,
);
assert!(ac_list.is_empty());
let mut ac_list = Vec::new();
header.insert_header("accept-encoding", "gzip").unwrap();
parse_accept_encoding(
header.headers.get(http::header::ACCEPT_ENCODING),
&mut ac_list,
);
assert_eq!(ac_list[0], Algorithm::Gzip);
let mut ac_list = Vec::new();
header
.insert_header("accept-encoding", "what, br, gzip")
.unwrap();
parse_accept_encoding(
header.headers.get(http::header::ACCEPT_ENCODING),
&mut ac_list,
);
assert_eq!(ac_list[0], Algorithm::Brotli);
assert_eq!(ac_list[1], Algorithm::Gzip);
}
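The structured-field parsing above relies on the `sfv` and `unicase` crates. The core idea (split on commas, match known content-coding tokens ASCII-case-insensitively, skip unknown codings) can be sketched with the standard library alone. `parse_codings` below is a hypothetical helper, not part of Pingora's API, and it simply drops any `;q=` parameters rather than honoring them:

```rust
// Minimal sketch: extract known content codings from an Accept-Encoding value.
// Unknown tokens are skipped, mirroring how the filter ignores `Algorithm::Other`.
fn parse_codings(value: &str) -> Vec<&'static str> {
    value
        .split(',')
        .map(|t| t.split(';').next().unwrap_or("").trim()) // drop ";q=..." params
        .filter(|t| !t.is_empty())
        .filter_map(|t| {
            if t.eq_ignore_ascii_case("gzip") {
                Some("gzip")
            } else if t.eq_ignore_ascii_case("br") {
                Some("br")
            } else if t.eq_ignore_ascii_case("zstd") {
                Some("zstd")
            } else {
                None // unknown coding: ignore
            }
        })
        .collect()
}

fn main() {
    assert_eq!(parse_codings("what, BR, gzip;q=0.5"), ["br", "gzip"]);
    assert!(parse_codings("identity").is_empty());
    println!("ok");
}
```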
// filter response header to see if (de)compression is needed
fn decide_action(resp: &ResponseHeader, accept_encoding: &[Algorithm]) -> Action {
use http::header::CONTENT_ENCODING;
let content_encoding = if let Some(ce) = resp.headers.get(CONTENT_ENCODING) {
// https://www.rfc-editor.org/rfc/rfc9110#name-content-encoding
if let Ok(ce_str) = std::str::from_utf8(ce.as_bytes()) {
Some(Algorithm::from(ce_str))
} else {
// not utf-8, treat it as unknown encoding to leave it untouched
Some(Algorithm::Other)
}
} else {
// no Content-Encoding header
None
};
if let Some(ce) = content_encoding {
if accept_encoding.contains(&ce) {
// downstream can accept this encoding, nothing to do
Action::Noop
} else {
// always decompress because uncompressed is always acceptable
// https://www.rfc-editor.org/rfc/rfc9110#field.accept-encoding
// "If the representation has no content coding, then it is acceptable by default
// unless specifically excluded..." TODO: check the exclude case
// TODO: we could also transcode it to a preferred encoding, e.g. br->gzip
Action::Decompress(ce)
}
} else if accept_encoding.is_empty() // both CE and AE are empty
|| !compressible(resp) // the type is not compressible
|| accept_encoding[0] == Algorithm::Any
{
Action::Noop
} else {
// try to compress with the first AC
// TODO: support to configure preferred encoding
Action::Compress(accept_encoding[0])
}
}
#[test]
fn test_decide_action() {
use Action::*;
use Algorithm::*;
let header = ResponseHeader::build(200, None).unwrap();
// no compression asked, no compression needed
assert_eq!(decide_action(&header, &[]), Noop);
// already gzip, no compression needed
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-type", "text/html").unwrap();
header.insert_header("content-encoding", "gzip").unwrap();
assert_eq!(decide_action(&header, &[Gzip]), Noop);
// already gzip, no compression needed, upper case
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-encoding", "GzIp").unwrap();
header.insert_header("content-type", "text/html").unwrap();
assert_eq!(decide_action(&header, &[Gzip]), Noop);
// no encoding, compression needed, accepted content-type, large enough
// Will compress
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-length", "20").unwrap();
header.insert_header("content-type", "text/html").unwrap();
assert_eq!(decide_action(&header, &[Gzip]), Compress(Gzip));
// too small
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-length", "19").unwrap();
header.insert_header("content-type", "text/html").unwrap();
assert_eq!(decide_action(&header, &[Gzip]), Noop);
// already compressed MIME
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-length", "20").unwrap();
header
.insert_header("content-type", "text/html+zip")
.unwrap();
assert_eq!(decide_action(&header, &[Gzip]), Noop);
// unsupported MIME
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-length", "20").unwrap();
header.insert_header("content-type", "image/jpg").unwrap();
assert_eq!(decide_action(&header, &[Gzip]), Noop);
// compressed, need decompress
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-encoding", "gzip").unwrap();
assert_eq!(decide_action(&header, &[]), Decompress(Gzip));
// accept-encoding different, need decompress
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-encoding", "gzip").unwrap();
assert_eq!(decide_action(&header, &[Brotli]), Decompress(Gzip));
// less preferred but no need to decompress
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-encoding", "gzip").unwrap();
assert_eq!(decide_action(&header, &[Brotli, Gzip]), Noop);
}
use once_cell::sync::Lazy;
use regex::Regex;
// Allow text, application, font, a few image/ MIME types and binary/octet-stream
// TODO: fine tune this list
static MIME_CHECK: Lazy<Regex> = Lazy::new(|| {
Regex::new(r"^(?:text/|application/|font/|image/(?:x-icon|svg\+xml|vnd\.microsoft\.icon)|binary/octet-stream)")
.unwrap()
});
// check if the response mime type is compressible
fn compressible(resp: &ResponseHeader) -> bool {
// arbitrary size limit, things to consider
// 1. too short body may have little redundancy to compress
// 2. gzip header and footer overhead
// 3. latency is the same as long as data fits in a TCP congestion window regardless of size
const MIN_COMPRESS_LEN: usize = 20;
// check if response is too small to compress
if let Some(cl) = resp.headers.get(http::header::CONTENT_LENGTH) {
if let Some(cl_num) = std::str::from_utf8(cl.as_bytes())
.ok()
.and_then(|v| v.parse::<usize>().ok())
{
if cl_num < MIN_COMPRESS_LEN {
return false;
}
}
}
// no Content-Length or large enough, check content-type next
if let Some(ct) = resp.headers.get(http::header::CONTENT_TYPE) {
if let Ok(ct_str) = std::str::from_utf8(ct.as_bytes()) {
if ct_str.contains("zip") {
// heuristic: don't compress mime type that has zip in it
false
} else {
// check if mime type in allow list
MIME_CHECK.find(ct_str).is_some()
}
} else {
false // invalid CT header, don't compress
}
} else {
false // don't compress empty content-type
}
}
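The allow-list above is a compiled regex; the same decision can be sketched with plain prefix tests. `mime_allowed` below is a hypothetical stand-in that approximates `MIME_CHECK` plus the "contains zip" heuristic; the real list also admits specific `image/` subtypes:

```rust
// Sketch of the compressibility check: a prefix allow-list plus the
// "contains zip" heuristic that skips already-compressed payloads.
fn mime_allowed(content_type: &str) -> bool {
    if content_type.contains("zip") {
        return false; // e.g. application/zip, text/html+zip
    }
    ["text/", "application/", "font/", "binary/octet-stream"]
        .iter()
        .any(|p| content_type.starts_with(p))
}

fn main() {
    assert!(mime_allowed("text/html; charset=utf-8"));
    assert!(!mime_allowed("application/zip"));
    assert!(!mime_allowed("image/jpeg"));
    println!("ok");
}
```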
fn adjust_response_header(resp: &mut ResponseHeader, action: &Action) {
use http::header::{HeaderValue, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
fn set_stream_headers(resp: &mut ResponseHeader) {
// because the transcoding is streamed, content length is not known ahead of time
resp.remove_header(&CONTENT_LENGTH);
// we stream body now TODO: chunked is for h1 only
resp.insert_header(&TRANSFER_ENCODING, HeaderValue::from_static("chunked"))
.unwrap();
}
match action {
Action::Noop => { /* do nothing */ }
Action::Decompress(_) => {
resp.remove_header(&CONTENT_ENCODING);
set_stream_headers(resp)
}
Action::Compress(a) => {
resp.insert_header(&CONTENT_ENCODING, HeaderValue::from_static(a.as_str()))
.unwrap();
set_stream_headers(resp)
}
}
}
#[test]
fn test_adjust_response_header() {
use Action::*;
use Algorithm::*;
// noop
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-length", "20").unwrap();
header.insert_header("content-encoding", "gzip").unwrap();
adjust_response_header(&mut header, &Noop);
assert_eq!(
header.headers.get("content-encoding").unwrap().as_bytes(),
b"gzip"
);
assert_eq!(
header.headers.get("content-length").unwrap().as_bytes(),
b"20"
);
assert!(header.headers.get("transfer-encoding").is_none());
// decompress gzip
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-length", "20").unwrap();
header.insert_header("content-encoding", "gzip").unwrap();
adjust_response_header(&mut header, &Decompress(Gzip));
assert!(header.headers.get("content-encoding").is_none());
assert!(header.headers.get("content-length").is_none());
assert_eq!(
header.headers.get("transfer-encoding").unwrap().as_bytes(),
b"chunked"
);
// compress
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header("content-length", "20").unwrap();
adjust_response_header(&mut header, &Compress(Gzip));
assert_eq!(
header.headers.get("content-encoding").unwrap().as_bytes(),
b"gzip"
);
assert!(header.headers.get("content-length").is_none());
assert_eq!(
header.headers.get("transfer-encoding").unwrap().as_bytes(),
b"chunked"
);
}


@@ -0,0 +1,91 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{Encode, COMPRESSION_ERROR};
use bytes::Bytes;
use parking_lot::Mutex;
use pingora_error::{OrErr, Result};
use std::io::Write;
use std::time::{Duration, Instant};
use zstd::stream::write::Encoder;
pub struct Compressor {
compress: Mutex<Encoder<'static, Vec<u8>>>,
total_in: usize,
total_out: usize,
duration: Duration,
}
impl Compressor {
pub fn new(level: u32) -> Self {
Compressor {
// Mutex because Encoder is not Sync
// https://github.com/gyscos/zstd-rs/issues/186
compress: Mutex::new(Encoder::new(vec![], level as i32).unwrap()),
total_in: 0,
total_out: 0,
duration: Duration::new(0, 0),
}
}
}
impl Encode for Compressor {
fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes> {
// reserve at most 16k
const MAX_INIT_COMPRESSED_BUF_SIZE: usize = 16 * 1024;
let start = Instant::now();
self.total_in += input.len();
let mut compress = self.compress.lock();
// reserve at most input size, cap at 16k, compressed output should be smaller
compress
.get_mut()
.reserve(std::cmp::min(MAX_INIT_COMPRESSED_BUF_SIZE, input.len()));
compress
.write_all(input)
.or_err(COMPRESSION_ERROR, "while compressing zstd")?;
// write to vec will never fail.
if end {
compress
.do_finish()
.or_err(COMPRESSION_ERROR, "while compressing zstd")?;
}
self.total_out += compress.get_ref().len();
self.duration += start.elapsed();
Ok(std::mem::take(compress.get_mut()).into()) // into() Bytes will drop excess capacity
}
fn stat(&self) -> (&'static str, usize, usize, Duration) {
("zstd", self.total_in, self.total_out, self.duration)
}
}
#[cfg(test)]
mod tests_stream {
use super::*;
#[test]
fn compress_zstd_data() {
let mut compressor = Compressor::new(11);
let input = b"adcdefgabcdefghadcdefgabcdefghadcdefgabcdefghadcdefgabcdefgh\n";
let compressed = compressor.encode(&input[..], false).unwrap();
// waiting for more data
assert!(compressed.is_empty());
let compressed = compressor.encode(&input[..], true).unwrap();
// the zstd Magic_Number
assert_eq!(&compressed[..4], &[0x28, 0xB5, 0x2F, 0xFD]);
assert!(compressed.len() < input.len());
}
}
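The zstd compressor above pairs the encoder with in/out byte counters and a running duration; that bookkeeping pattern is independent of the codec. The sketch below applies it to a trivial pass-through "encoder" (hypothetical, for illustration only, with no real compression):

```rust
use std::time::{Duration, Instant};

// Pass-through "encoder" that tracks the same stats the real compressors do:
// total bytes in, total bytes out, and cumulative time spent encoding.
struct PassThrough {
    total_in: usize,
    total_out: usize,
    duration: Duration,
}

impl PassThrough {
    fn new() -> Self {
        PassThrough { total_in: 0, total_out: 0, duration: Duration::ZERO }
    }

    fn encode(&mut self, input: &[u8], _end: bool) -> Vec<u8> {
        let start = Instant::now();
        self.total_in += input.len();
        let out = input.to_vec(); // a real codec would transform the bytes here
        self.total_out += out.len();
        self.duration += start.elapsed();
        out
    }

    fn stat(&self) -> (&'static str, usize, usize, Duration) {
        ("identity", self.total_in, self.total_out, self.duration)
    }
}

fn main() {
    let mut enc = PassThrough::new();
    enc.encode(b"hello", false);
    enc.encode(b" world", true);
    let (name, inb, outb, _) = enc.stat();
    assert_eq!((name, inb, outb), ("identity", 11, 11));
    println!("ok");
}
```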


@@ -0,0 +1,90 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::NaiveDateTime;
use http::header::HeaderValue;
use std::cell::RefCell;
use std::time::{Duration, SystemTime};
fn to_date_string(epoch_sec: i64) -> String {
let dt = NaiveDateTime::from_timestamp_opt(epoch_sec, 0).unwrap();
dt.format("%a, %d %b %Y %H:%M:%S GMT").to_string()
}
struct CacheableDate {
h1_date: HeaderValue,
epoch: Duration,
}
impl CacheableDate {
pub fn new() -> Self {
let d = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap();
CacheableDate {
h1_date: HeaderValue::from_str(&to_date_string(d.as_secs() as i64)).unwrap(),
epoch: d,
}
}
pub fn update(&mut self, d_now: Duration) {
if d_now.as_secs() != self.epoch.as_secs() {
self.epoch = d_now;
self.h1_date = HeaderValue::from_str(&to_date_string(d_now.as_secs() as i64)).unwrap();
}
}
pub fn get_date(&mut self) -> HeaderValue {
let d = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap();
self.update(d);
self.h1_date.clone()
}
}
thread_local! {
static CACHED_DATE: RefCell<CacheableDate>
= RefCell::new(CacheableDate::new());
}
pub fn get_cached_date() -> HeaderValue {
CACHED_DATE.with(|cache_date| (*cache_date.borrow_mut()).get_date())
}
#[cfg(test)]
mod test {
use super::*;
fn now_date_header() -> HeaderValue {
HeaderValue::from_str(&to_date_string(
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs() as i64,
))
.unwrap()
}
#[test]
fn test_date_string() {
let date_str = to_date_string(1);
assert_eq!("Thu, 01 Jan 1970 00:00:01 GMT", date_str);
}
#[test]
fn test_date_cached() {
assert_eq!(get_cached_date(), now_date_header());
}
}
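The date cache above re-renders the header string only when the epoch second changes. That update rule can be tested in isolation by injecting durations instead of calling `SystemTime::now()`; `CachedSecond` is a hypothetical simplification that caches a plain string rather than a formatted HTTP date:

```rust
use std::time::Duration;

// Cache a rendered value keyed by the current epoch second; re-render only
// when the second changes, like the Date header cache above.
struct CachedSecond {
    rendered: String,
    epoch: Duration,
}

impl CachedSecond {
    fn new(now: Duration) -> Self {
        CachedSecond { rendered: now.as_secs().to_string(), epoch: now }
    }

    fn get(&mut self, now: Duration) -> &str {
        if now.as_secs() != self.epoch.as_secs() {
            self.epoch = now;
            self.rendered = now.as_secs().to_string(); // re-render on second change
        }
        &self.rendered
    }
}

fn main() {
    let mut c = CachedSecond::new(Duration::from_millis(1500));
    assert_eq!(c.get(Duration::from_millis(1900)), "1"); // same second: cached
    assert_eq!(c.get(Duration::from_millis(2100)), "2"); // new second: re-rendered
    println!("ok");
}
```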


@@ -0,0 +1,41 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Error response generating utilities.
use http::header;
use once_cell::sync::Lazy;
use pingora_http::ResponseHeader;
use super::SERVER_NAME;
/// Generate an error response with the given status code.
///
/// This error response has a zero `Content-Length` and `Cache-Control: private, no-store`.
pub fn gen_error_response(code: u16) -> ResponseHeader {
let mut resp = ResponseHeader::build(code, Some(4)).unwrap();
resp.insert_header(header::SERVER, &SERVER_NAME[..])
.unwrap();
resp.insert_header(header::DATE, "Sun, 06 Nov 1994 08:49:37 GMT")
.unwrap(); // placeholder
resp.insert_header(header::CONTENT_LENGTH, "0").unwrap();
resp.insert_header(header::CACHE_CONTROL, "private, no-store")
.unwrap();
resp
}
/// Pre-generated 502 response
pub static HTTP_502_RESPONSE: Lazy<ResponseHeader> = Lazy::new(|| gen_error_response(502));
/// Pre-generated 400 response
pub static HTTP_400_RESPONSE: Lazy<ResponseHeader> = Lazy::new(|| gen_error_response(400));


@@ -0,0 +1,57 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HTTP/1.x and HTTP/2 implementation APIs
mod body_buffer;
pub mod client;
pub mod compression;
pub(crate) mod date;
pub mod error_resp;
pub mod server;
pub mod v1;
pub mod v2;
pub use server::Session as ServerSession;
/// The Pingora server name string
pub const SERVER_NAME: &[u8; 7] = b"Pingora";
/// An enum to hold all possible HTTP response events.
#[derive(Debug)]
pub enum HttpTask {
/// the response header and the boolean end of response flag
Header(Box<pingora_http::ResponseHeader>, bool),
/// A piece of the response body and the end of response boolean flag
Body(Option<bytes::Bytes>, bool),
/// HTTP response trailer
Trailer(Option<Box<http::HeaderMap>>),
/// Signal that the response is already finished
Done,
/// Signal that the reading of the response encounters errors.
Failed(pingora_error::BError),
}
impl HttpTask {
/// Whether this [`HttpTask`] means the end of the response
pub fn is_end(&self) -> bool {
match self {
HttpTask::Header(_, end) => *end,
HttpTask::Body(_, end) => *end,
HttpTask::Trailer(_) => true,
HttpTask::Done => true,
HttpTask::Failed(_) => true,
}
}
}


@@ -0,0 +1,333 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HTTP server session APIs
use super::error_resp;
use super::v1::server::HttpSession as SessionV1;
use super::v2::server::HttpSession as SessionV2;
use super::HttpTask;
use crate::protocols::Stream;
use bytes::Bytes;
use http::header::AsHeaderName;
use http::HeaderValue;
use log::error;
use pingora_error::Result;
use pingora_http::{RequestHeader, ResponseHeader};
/// HTTP server session object for both HTTP/1.x and HTTP/2
pub enum Session {
H1(SessionV1),
H2(SessionV2),
}
impl Session {
/// Create a new [`Session`] from an established connection for HTTP/1.x
pub fn new_http1(stream: Stream) -> Self {
Self::H1(SessionV1::new(stream))
}
/// Create a new [`Session`] from an established HTTP/2 stream
pub fn new_http2(session: SessionV2) -> Self {
Self::H2(session)
}
/// Whether the session is HTTP/2. If not it is HTTP/1.x
pub fn is_http2(&self) -> bool {
matches!(self, Self::H2(_))
}
/// Read the request header. This method is required to be called first before doing anything
/// else with the session.
/// - `Ok(true)`: successful
/// - `Ok(false)`: client exited without sending any bytes. This is normal on a reused connection.
/// In this case the user should give up this session.
pub async fn read_request(&mut self) -> Result<bool> {
match self {
Self::H1(s) => {
let read = s.read_request().await?;
Ok(read.is_some())
}
// This call will always return `Ok(true)` for Http2 because the request is already read
Self::H2(_) => Ok(true),
}
}
/// Return the request header it just read.
/// # Panic
/// This function will panic if [`Self::read_request()`] is not called.
pub fn req_header(&self) -> &RequestHeader {
match self {
Self::H1(s) => s.req_header(),
Self::H2(s) => s.req_header(),
}
}
/// Return a mutable reference to request header it just read.
/// # Panic
/// This function will panic if [`Self::read_request()`] is not called.
pub fn req_header_mut(&mut self) -> &mut RequestHeader {
match self {
Self::H1(s) => s.req_header_mut(),
Self::H2(s) => s.req_header_mut(),
}
}
/// Return the header by name. None if the header doesn't exist.
///
/// In case there are multiple headers under the same name, the first one will be returned. To
/// get all the headers: use `self.req_header().headers.get_all()`.
pub fn get_header<K: AsHeaderName>(&self, key: K) -> Option<&HeaderValue> {
self.req_header().headers.get(key)
}
/// Get the header value in its raw format.
/// If the header doesn't exist, return an empty slice.
pub fn get_header_bytes<K: AsHeaderName>(&self, key: K) -> &[u8] {
self.get_header(key).map_or(b"", |v| v.as_bytes())
}
/// Read the request body. Ok(None) if no (more) body to read
pub async fn read_request_body(&mut self) -> Result<Option<Bytes>> {
match self {
Self::H1(s) => s.read_body_bytes().await,
Self::H2(s) => s.read_body_bytes().await,
}
}
/// Write the response header to client
/// Informational headers (status code 100-199, excluding 101) can be written multiple times before
/// the final response header (status code 200+ or 101) is written.
pub async fn write_response_header(&mut self, resp: Box<ResponseHeader>) -> Result<()> {
match self {
Self::H1(s) => {
s.write_response_header(resp).await?;
Ok(())
}
Self::H2(s) => s.write_response_header(resp, false),
}
}
/// Similar to `write_response_header()`, this fn will clone the `resp` internally
pub async fn write_response_header_ref(&mut self, resp: &ResponseHeader) -> Result<()> {
match self {
Self::H1(s) => {
s.write_response_header_ref(resp).await?;
Ok(())
}
Self::H2(s) => s.write_response_header_ref(resp, false),
}
}
/// Write the response body to client
pub async fn write_response_body(&mut self, data: Bytes) -> Result<()> {
match self {
Self::H1(s) => {
s.write_body(&data).await?;
Ok(())
}
Self::H2(s) => s.write_body(data, false),
}
}
/// Finish the life of this request.
/// For H1, if connection reuse is supported, a Some(Stream) will be returned, otherwise None.
/// For H2, always return None because H2 stream is not reusable.
pub async fn finish(self) -> Result<Option<Stream>> {
match self {
Self::H1(mut s) => {
// need to flush body due to buffering
s.finish_body().await?;
Ok(s.reuse().await)
}
Self::H2(mut s) => {
s.finish()?;
Ok(None)
}
}
}
pub async fn response_duplex_vec(&mut self, tasks: Vec<HttpTask>) -> Result<bool> {
match self {
Self::H1(s) => s.response_duplex_vec(tasks).await,
Self::H2(s) => s.response_duplex_vec(tasks),
}
}
/// Set connection reuse. `duration` defines how long the connection is kept open for the next
/// request to reuse. Noop for h2
pub fn set_keepalive(&mut self, duration: Option<u64>) {
match self {
Self::H1(s) => s.set_server_keepalive(duration),
Self::H2(_) => {}
}
}
/// Return a digest of the request including the method, path and Host header
// TODO: make this use a `Formatter`
pub fn request_summary(&self) -> String {
match self {
Self::H1(s) => s.request_summary(),
Self::H2(s) => s.request_summary(),
}
}
/// Return the written response header. `None` if it is not written yet.
/// Only the final (status code >= 200 or 101) response header will be returned
pub fn response_written(&self) -> Option<&ResponseHeader> {
match self {
Self::H1(s) => s.response_written(),
Self::H2(s) => s.response_written(),
}
}
/// Give up the http session abruptly.
/// For H1 this will close the underlying connection
/// For H2 this will send RESET frame to end this stream without impacting the connection
pub async fn shutdown(&mut self) {
match self {
Self::H1(s) => s.shutdown().await,
Self::H2(s) => s.shutdown(),
}
}
pub fn to_h1_raw(&self) -> Bytes {
match self {
Self::H1(s) => s.get_headers_raw_bytes(),
Self::H2(s) => s.pseudo_raw_h1_request_header(),
}
}
/// Whether the whole request body is sent
pub fn is_body_done(&mut self) -> bool {
match self {
Self::H1(s) => s.is_body_done(),
Self::H2(s) => s.is_body_done(),
}
}
/// Notify the client that the entire body is sent
/// for H1 chunked encoding, this will send the last (empty) chunk
/// for H1 content-length, this has no effect.
/// for H2, this will send an empty DATA frame with END_STREAM flag
pub async fn finish_body(&mut self) -> Result<()> {
match self {
Self::H1(s) => s.finish_body().await.map(|_| ()),
Self::H2(s) => s.finish(),
}
}
/// Send error response to client
pub async fn respond_error(&mut self, error: u16) {
let resp = match error {
/* common error responses are pre-generated */
502 => error_resp::HTTP_502_RESPONSE.clone(),
400 => error_resp::HTTP_400_RESPONSE.clone(),
_ => error_resp::gen_error_response(error),
};
// TODO: we shouldn't be closing downstream connections on internally generated errors
// and possibly other upstream connect() errors (connection refused, timeout, etc)
//
// This change is only here because we DO NOT re-use downstream connections
// today on these errors and we should signal to the client that pingora is dropping it
// rather than misleading the client with 'keep-alive'
self.set_keepalive(None);
self.write_response_header(Box::new(resp))
.await
.unwrap_or_else(|e| {
error!("failed to send error response to downstream: {e}");
});
}
/// Whether there is no request body
pub fn is_body_empty(&mut self) -> bool {
match self {
Self::H1(s) => s.is_body_empty(),
Self::H2(s) => s.is_body_empty(),
}
}
pub fn retry_buffer_truncated(&self) -> bool {
match self {
Self::H1(s) => s.retry_buffer_truncated(),
Self::H2(s) => s.retry_buffer_truncated(),
}
}
pub fn enable_retry_buffering(&mut self) {
match self {
Self::H1(s) => s.enable_retry_buffering(),
Self::H2(s) => s.enable_retry_buffering(),
}
}
pub fn get_retry_buffer(&self) -> Option<Bytes> {
match self {
Self::H1(s) => s.get_retry_buffer(),
Self::H2(s) => s.get_retry_buffer(),
}
}
/// Read body (same as `read_request_body()`) or pending forever until downstream
/// terminates the session.
pub async fn read_body_or_idle(&mut self, no_body_expected: bool) -> Result<Option<Bytes>> {
match self {
Self::H1(s) => s.read_body_or_idle(no_body_expected).await,
Self::H2(s) => s.read_body_or_idle(no_body_expected).await,
}
}
pub fn as_http1(&self) -> Option<&SessionV1> {
match self {
Self::H1(s) => Some(s),
Self::H2(_) => None,
}
}
pub fn as_http2(&self) -> Option<&SessionV2> {
match self {
Self::H1(_) => None,
Self::H2(s) => Some(s),
}
}
/// Write a 100 Continue response to the client.
pub async fn write_continue_response(&mut self) -> Result<()> {
match self {
Self::H1(s) => s.write_continue_response().await,
Self::H2(s) => s.write_response_header(
Box::new(ResponseHeader::build(100, Some(0)).unwrap()),
false,
),
}
}
/// Whether this request is for upgrade (e.g., websocket)
pub fn is_upgrade_req(&self) -> bool {
match self {
Self::H1(s) => s.is_upgrade_req(),
Self::H2(_) => false,
}
}
/// How many response body bytes already sent
pub fn body_bytes_sent(&self) -> usize {
match self {
Self::H1(s) => s.body_bytes_sent(),
Self::H2(s) => s.body_bytes_sent(),
}
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,237 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Common functions and constants
use http::header;
use log::warn;
use pingora_http::{HMap, RequestHeader, ResponseHeader};
use std::str;
use std::time::Duration;
use super::body::BodyWriter;
use crate::utils::KVRef;
pub(super) const MAX_HEADERS: usize = 256;
pub(super) const INIT_HEADER_BUF_SIZE: usize = 4096;
pub(super) const MAX_HEADER_SIZE: usize = 1048575;
pub(super) const BODY_BUF_LIMIT: usize = 1024 * 64;
pub const CRLF: &[u8; 2] = b"\r\n";
pub const HEADER_KV_DELIMITER: &[u8; 2] = b": ";
pub(super) enum HeaderParseState {
Complete(usize),
Partial,
Invalid(httparse::Error),
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) enum KeepaliveStatus {
Timeout(Duration),
Infinite,
Off,
}
struct ConnectionValue {
keep_alive: bool,
upgrade: bool,
close: bool,
}
impl ConnectionValue {
fn new() -> Self {
ConnectionValue {
keep_alive: false,
upgrade: false,
close: false,
}
}
fn close(mut self) -> Self {
self.close = true;
self
}
fn upgrade(mut self) -> Self {
self.upgrade = true;
self
}
fn keep_alive(mut self) -> Self {
self.keep_alive = true;
self
}
}
fn parse_connection_header(value: &[u8]) -> ConnectionValue {
// only parse keep-alive, close, and upgrade tokens
// https://www.rfc-editor.org/rfc/rfc9110.html#section-7.6.1
const KEEP_ALIVE: &str = "keep-alive";
const CLOSE: &str = "close";
const UPGRADE: &str = "upgrade";
// fast path
if value.eq_ignore_ascii_case(CLOSE.as_bytes()) {
ConnectionValue::new().close()
} else if value.eq_ignore_ascii_case(KEEP_ALIVE.as_bytes()) {
ConnectionValue::new().keep_alive()
} else if value.eq_ignore_ascii_case(UPGRADE.as_bytes()) {
ConnectionValue::new().upgrade()
} else {
// slow path, parse the connection value
let mut close = false;
let mut upgrade = false;
let value = str::from_utf8(value).unwrap_or("");
for token in value
.split(',')
.map(|s| s.trim())
.filter(|&x| !x.is_empty())
{
if token.eq_ignore_ascii_case(CLOSE) {
close = true;
} else if token.eq_ignore_ascii_case(UPGRADE) {
upgrade = true;
}
if upgrade && close {
return ConnectionValue::new().upgrade().close();
}
}
if close {
ConnectionValue::new().close()
} else if upgrade {
ConnectionValue::new().upgrade()
} else {
ConnectionValue::new()
}
}
}
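Note that the slow path above tracks only `close` and `upgrade`; a comma-separated list value never sets `keep_alive`. The standalone sketch below reproduces just the tokenization rule (split on commas, trim, ASCII-case-insensitive match); `connection_flags` is a hypothetical helper, not part of this module:

```rust
// Tokenize a Connection header value and report (close, upgrade) flags,
// matching tokens ASCII-case-insensitively as in the slow path above.
fn connection_flags(value: &str) -> (bool, bool) {
    let mut close = false;
    let mut upgrade = false;
    for token in value.split(',').map(str::trim).filter(|t| !t.is_empty()) {
        if token.eq_ignore_ascii_case("close") {
            close = true;
        } else if token.eq_ignore_ascii_case("upgrade") {
            upgrade = true;
        }
    }
    (close, upgrade)
}

fn main() {
    assert_eq!(connection_flags("Upgrade, Close"), (true, true));
    assert_eq!(connection_flags("keep-alive"), (false, false));
    println!("ok");
}
```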
pub(crate) fn init_body_writer_comm(body_writer: &mut BodyWriter, headers: &HMap) {
let te_value = headers.get(http::header::TRANSFER_ENCODING);
if is_header_value_chunked_encoding(te_value) {
// transfer-encoding takes priority over content-length
body_writer.init_chunked();
} else {
let content_length = header_value_content_length(headers.get(http::header::CONTENT_LENGTH));
match content_length {
Some(length) => {
body_writer.init_content_length(length);
}
None => {
/* TODO: 1. connection: keepalive cannot be used,
2. mark connection must be closed */
body_writer.init_http10();
}
}
}
}
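The framing decision above can be sketched in isolation (the enum and function names below are illustrative, not Pingora's API): `Transfer-Encoding: chunked` takes priority over `Content-Length`, and a missing or unparseable length falls back to HTTP/1.0-style read-until-close framing.

```rust
// Standalone sketch of choosing a response body framing mode.
#[derive(Debug, PartialEq)]
enum Framing {
    Chunked,
    ContentLength(usize),
    Http10CloseDelimited,
}

fn pick_framing(te: Option<&str>, cl: Option<&str>) -> Framing {
    if te.map_or(false, |v| v.eq_ignore_ascii_case("chunked")) {
        // transfer-encoding takes priority over content-length
        Framing::Chunked
    } else if let Some(len) = cl.and_then(|v| v.parse::<usize>().ok()) {
        Framing::ContentLength(len)
    } else {
        // no usable length: body is delimited by connection close
        Framing::Http10CloseDelimited
    }
}

fn main() {
    assert_eq!(pick_framing(Some("chunked"), Some("42")), Framing::Chunked);
    assert_eq!(pick_framing(None, Some("42")), Framing::ContentLength(42));
    assert_eq!(pick_framing(None, None), Framing::Http10CloseDelimited);
}
```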
#[inline]
pub(super) fn is_header_value_chunked_encoding(
header_value: Option<&http::header::HeaderValue>,
) -> bool {
match header_value {
Some(value) => value.as_bytes().eq_ignore_ascii_case(b"chunked"),
None => false,
}
}
pub(super) fn is_upgrade_req(req: &RequestHeader) -> bool {
req.version == http::Version::HTTP_11 && req.headers.get(header::UPGRADE).is_some()
}
// Unlike the upgrade check on the request, this function doesn't check the Upgrade or Connection
// headers: when we see a 101, we assume the server has accepted the protocol switch.
// In reality, it is not uncommon for servers to omit some of the headers required to establish
// websocket connections.
pub(super) fn is_upgrade_resp(header: &ResponseHeader) -> bool {
header.status == 101 && header.version == http::Version::HTTP_11
}
#[inline]
pub fn header_value_content_length(
header_value: Option<&http::header::HeaderValue>,
) -> Option<usize> {
match header_value {
Some(value) => buf_to_content_length(Some(value.as_bytes())),
None => None,
}
}
#[inline]
pub(super) fn buf_to_content_length(header_value: Option<&[u8]>) -> Option<usize> {
match header_value {
Some(buf) => {
match str::from_utf8(buf) {
// check valid string
Ok(str_cl_value) => match str_cl_value.parse::<i64>() {
Ok(cl_length) => {
if cl_length >= 0 {
Some(cl_length as usize)
} else {
warn!("negative content-length header value {cl_length}");
None
}
}
Err(_) => {
warn!("invalid content-length header value {str_cl_value}");
None
}
},
Err(_) => {
warn!("invalid content-length header encoding");
None
}
}
}
None => None,
}
}
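A condensed sketch of the validation above (standalone, without the logging): a `Content-Length` value must be valid UTF-8 and parse as a non-negative integer; anything else yields `None` so the caller can fall back to other framing.

```rust
// Standalone sketch of Content-Length validation.
fn parse_content_length(buf: &[u8]) -> Option<usize> {
    let s = std::str::from_utf8(buf).ok()?; // reject invalid encoding
    match s.parse::<i64>() {
        Ok(n) if n >= 0 => Some(n as usize),
        _ => None, // negative, overflowing, or non-numeric value
    }
}

fn main() {
    assert_eq!(parse_content_length(b"1024"), Some(1024));
    assert_eq!(parse_content_length(b"-1"), None); // negative is rejected
    assert_eq!(parse_content_length(b"abc"), None); // non-numeric is rejected
}
```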
#[inline]
pub(super) fn is_buf_keepalive(header_value: Option<&[u8]>) -> Option<bool> {
header_value.and_then(|value| {
let value = parse_connection_header(value);
if value.keep_alive {
Some(true)
} else if value.close {
Some(false)
} else {
None
}
})
}
#[inline]
pub(super) fn populate_headers(
base: usize,
header_ref: &mut Vec<KVRef>,
headers: &[httparse::Header],
) -> usize {
let mut used_header_index = 0;
for header in headers.iter() {
if !header.name.is_empty() {
header_ref.push(KVRef::new(
header.name.as_ptr() as usize - base,
header.name.as_bytes().len(),
header.value.as_ptr() as usize - base,
header.value.len(),
));
used_header_index += 1;
}
}
used_header_index
}


@ -0,0 +1,20 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HTTP/1.x implementation
pub(crate) mod body;
pub mod client;
pub mod common;
pub mod server;

File diff suppressed because it is too large


@ -0,0 +1,480 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HTTP/2 client session and connection
// TODO: this module needs a refactor
use bytes::Bytes;
use h2::client::{self, ResponseFuture, SendRequest};
use h2::{Reason, RecvStream, SendStream};
use http::HeaderMap;
use log::{debug, error, warn};
use pingora_error::{Error, ErrorType, ErrorType::*, OrErr, Result, RetryType};
use pingora_http::{RequestHeader, ResponseHeader};
use pingora_timeout::timeout;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::watch;
use crate::connectors::http::v2::ConnectionRef;
use crate::protocols::Digest;
pub const PING_TIMEDOUT: ErrorType = ErrorType::new("PingTimedout");
pub struct Http2Session {
send_req: SendRequest<Bytes>,
send_body: Option<SendStream<Bytes>>,
resp_fut: Option<ResponseFuture>,
req_sent: Option<Box<RequestHeader>>,
response_header: Option<ResponseHeader>,
response_body_reader: Option<RecvStream>,
/// The read timeout, which will be applied to both reading the header and the body.
/// The timeout is reset on every read. This is not a timeout on the overall duration of the
/// response.
pub read_timeout: Option<Duration>,
pub(crate) conn: ConnectionRef,
// Indicates whether an END_STREAM frame has already been sent
ended: bool,
}
impl Drop for Http2Session {
fn drop(&mut self) {
self.conn.release_stream();
}
}
impl Http2Session {
pub(crate) fn new(send_req: SendRequest<Bytes>, conn: ConnectionRef) -> Self {
Http2Session {
send_req,
send_body: None,
resp_fut: None,
req_sent: None,
response_header: None,
response_body_reader: None,
read_timeout: None,
conn,
ended: false,
}
}
fn sanitize_request_header(req: &mut RequestHeader) -> Result<()> {
req.set_version(http::Version::HTTP_2);
if req.uri.authority().is_some() {
return Ok(());
}
// use host header to populate :authority field
let Some(authority) = req.headers.get(http::header::HOST).map(|v| v.as_bytes()) else {
return Error::e_explain(InvalidHTTPHeader, "no authority header for h2");
};
let uri = http::uri::Builder::new()
.scheme("https") // fixed for now
.authority(authority)
.path_and_query(req.uri.path_and_query().as_ref().unwrap().as_str())
.build();
match uri {
Ok(uri) => {
req.set_uri(uri);
Ok(())
}
Err(_) => Error::e_explain(
InvalidHTTPHeader,
format!("invalid authority from host {authority:?}"),
),
}
}
/// Write the request header to the server
pub fn write_request_header(&mut self, mut req: Box<RequestHeader>, end: bool) -> Result<()> {
if self.req_sent.is_some() {
// cannot send again, TODO: warn
return Ok(());
}
Self::sanitize_request_header(&mut req)?;
let parts = req.as_owned_parts();
let request = http::Request::from_parts(parts, ());
// There is no write timeout for h2 because the actual write happens async from this fn
let (resp_fut, send_body) = self
.send_req
.send_request(request, end)
.or_err(H2Error, "while sending request")
.map_err(|e| self.handle_err(e))?;
self.req_sent = Some(req);
self.send_body = Some(send_body);
self.resp_fut = Some(resp_fut);
self.ended = self.ended || end;
Ok(())
}
/// Write a request body chunk
pub fn write_request_body(&mut self, data: Bytes, end: bool) -> Result<()> {
if self.ended {
warn!("Try to write request body after end of stream, dropping the extra data");
return Ok(());
}
let body_writer = self
.send_body
.as_mut()
.expect("Try to write request body before sending request header");
write_body(body_writer, data, end).map_err(|e| self.handle_err(e))?;
self.ended = self.ended || end;
Ok(())
}
/// Signal that the request body has ended
pub fn finish_request_body(&mut self) -> Result<()> {
if self.ended {
return Ok(());
}
let body_writer = self
.send_body
.as_mut()
.expect("Try to finish request stream before sending request header");
// Just send an empty data frame with end of stream set
body_writer
.send_data("".into(), true)
.or_err(WriteError, "while writing empty h2 request body")
.map_err(|e| self.handle_err(e))?;
self.ended = true;
Ok(())
}
/// Read the response header
pub async fn read_response_header(&mut self) -> Result<()> {
// TODO: how to read 1xx headers?
// https://github.com/hyperium/h2/issues/167
if self.response_header.is_some() {
panic!("H2 response header is already read")
}
let Some(resp_fut) = self.resp_fut.take() else {
panic!("Try to read response header before the request is sent, or it is already read")
};
let res = match self.read_timeout {
Some(t) => timeout(t, resp_fut)
.await
.map_err(|_| Error::explain(ReadTimedout, "while reading h2 response header"))
.map_err(|e| self.handle_err(e))?,
None => resp_fut.await,
};
let (resp, body_reader) = res.map_err(handle_read_header_error)?.into_parts();
self.response_header = Some(resp.into());
self.response_body_reader = Some(body_reader);
Ok(())
}
/// Read the response body
///
/// `None` means no more body to read
pub async fn read_response_body(&mut self) -> Result<Option<Bytes>> {
let Some(body_reader) = self.response_body_reader.as_mut() else {
// req is not sent or response is already read
// TODO: warn
return Ok(None);
};
if body_reader.is_end_stream() {
return Ok(None);
}
let fut = body_reader.data();
let res = match self.read_timeout {
Some(t) => timeout(t, fut)
.await
.map_err(|_| Error::explain(ReadTimedout, "while reading h2 response body"))?,
None => fut.await,
};
let body = res
.transpose()
.or_err(ReadError, "while reading h2 response body")
.map_err(|mut e| {
// cannot use handle_err() because of borrow checker
if self.conn.ping_timedout() {
e.etype = PING_TIMEDOUT;
}
e
})?;
if let Some(data) = body.as_ref() {
body_reader
.flow_control()
.release_capacity(data.len())
.or_err(ReadError, "while releasing h2 response body capacity")?;
}
Ok(body)
}
/// Whether the response has ended
pub fn response_finished(&self) -> bool {
// if response_body_reader doesn't exist, the response is not even read yet
self.response_body_reader
.as_ref()
.map_or(false, |reader| reader.is_end_stream())
}
/// Read the optional trailer headers
pub async fn read_trailers(&mut self) -> Result<Option<HeaderMap>> {
let Some(reader) = self.response_body_reader.as_mut() else {
// response is not even read
// TODO: warn
return Ok(None);
};
let fut = reader.trailers();
let res = match self.read_timeout {
Some(t) => timeout(t, fut)
.await
.map_err(|_| Error::explain(ReadTimedout, "while reading h2 trailer"))
.map_err(|e| self.handle_err(e))?,
None => fut.await,
};
match res {
Ok(t) => Ok(t),
Err(e) => {
// GOAWAY with no error: this is graceful shutdown, continue as if no trailer
// RESET_STREAM with no error: https://datatracker.ietf.org/doc/html/rfc9113#section-8.1:
// this is to signal client to stop uploading request without breaking the response.
// TODO: should actually stop uploading
// TODO: should we try reading again?
// TODO: handle this when reading headers and body as well
// https://github.com/hyperium/h2/issues/741
if (e.is_go_away() || e.is_reset())
&& e.is_remote()
&& e.reason() == Some(Reason::NO_ERROR)
{
Ok(None)
} else {
Err(e)
}
}
}
.or_err(ReadError, "while reading h2 trailers")
}
/// The response header if it is already read
pub fn response_header(&self) -> Option<&ResponseHeader> {
self.response_header.as_ref()
}
/// Give up the http session abruptly.
pub fn shutdown(&mut self) {
if !self.ended || !self.response_finished() {
if let Some(send_body) = self.send_body.as_mut() {
send_body.send_reset(h2::Reason::INTERNAL_ERROR)
}
}
}
/// Drop everything in this h2 stream. Return the connection ref.
/// After this function the underlying h2 connection should already notify the closure of this
/// stream so that another stream can be created if needed.
pub(crate) fn conn(&self) -> ConnectionRef {
self.conn.clone()
}
/// Whether ping timeout occurred. After a ping timeout, the h2 connection will be terminated.
/// Ongoing h2 streams will receive a stream/connection error. The streams should check this
/// flag to tell whether the error is triggered by the timeout.
pub(crate) fn ping_timedout(&self) -> bool {
self.conn.ping_timedout()
}
/// Return the [Digest] of the connection
///
/// For a reused connection, the timing in the digest reflects its initial handshakes.
/// The caller should check whether the connection is reused to avoid misusing the timing fields.
pub fn digest(&self) -> Option<&Digest> {
Some(self.conn.digest())
}
/// the FD of the underlying connection
pub fn fd(&self) -> i32 {
self.conn.id()
}
/// take the body sender to another task to perform duplex read and write
pub fn take_request_body_writer(&mut self) -> Option<SendStream<Bytes>> {
self.send_body.take()
}
fn handle_err(&self, mut e: Box<Error>) -> Box<Error> {
if self.ping_timedout() {
e.etype = PING_TIMEDOUT;
}
e
}
}
/// A helper function to write the request body
pub fn write_body(send_body: &mut SendStream<Bytes>, data: Bytes, end: bool) -> Result<()> {
let data_len = data.len();
send_body.reserve_capacity(data_len);
send_body
.send_data(data, end)
.or_err(WriteError, "while writing h2 request body")
}
/* helper functions */
/* Types of errors during h2 header read
1. peer requests to downgrade to h1, mostly IIS server for NTLM: we will downgrade and retry
2. peer sends invalid h2 frames, usually sending h1 only header: we will downgrade and retry
3. peer sends GO_AWAY(NO_ERROR) on reused conn, usually hit http2_max_requests: we will retry
4. peer IO error on reused conn, usually firewall kills old conn: we will retry
5. All other errors will terminate the request
*/
fn handle_read_header_error(e: h2::Error) -> Box<Error> {
if e.is_remote()
&& e.reason()
.map_or(false, |r| r == h2::Reason::HTTP_1_1_REQUIRED)
{
let mut err = Error::because(H2Downgrade, "while reading h2 header", e);
err.retry = true.into();
err
} else if e.is_go_away()
&& e.is_library()
&& e.reason()
.map_or(false, |r| r == h2::Reason::PROTOCOL_ERROR)
{
// remote sent invalid H2 responses
let mut err = Error::because(InvalidH2, "while reading h2 header", e);
err.retry = true.into();
err
} else if e.is_go_away()
&& e.is_remote()
&& e.reason().map_or(false, |r| r == h2::Reason::NO_ERROR)
{
// is_go_away: retry via another connection, this connection is being teardown
// only retry if the connection is reused
let mut err = Error::because(H2Error, "while reading h2 header", e);
err.retry = RetryType::ReusedOnly;
err
} else if e.is_io() {
// is_io: typical if a previously reused connection silently drops it
// only retry if the connection is reused
let true_io_error = e.get_io().unwrap().raw_os_error().is_some();
let mut err = Error::because(ReadError, "while reading h2 header", e);
if true_io_error {
err.retry = RetryType::ReusedOnly;
} // else could be TLS error, which is unsafe to retry
err
} else {
Error::because(H2Error, "while reading h2 header", e)
}
}
use tokio::sync::oneshot;
pub async fn drive_connection<S>(
mut c: client::Connection<S>,
id: i32,
closed: watch::Sender<bool>,
ping_interval: Option<Duration>,
ping_timeout_occurred: Arc<AtomicBool>,
) where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
let interval = ping_interval.unwrap_or(Duration::ZERO);
if !interval.is_zero() {
// for ping to inform this fn to drop the connection
let (tx, rx) = oneshot::channel::<()>();
// for this fn to inform ping to give up when it is already dropped
let dropped = Arc::new(AtomicBool::new(false));
let dropped2 = dropped.clone();
if let Some(ping_pong) = c.ping_pong() {
pingora_runtime::current_handle().spawn(async move {
do_ping_pong(ping_pong, interval, tx, dropped2, id).await;
});
} else {
warn!("Cannot get ping-pong handler from h2 connection");
}
tokio::select! {
r = c => match r {
Ok(_) => debug!("H2 connection finished fd: {id}"),
Err(e) => debug!("H2 connection fd: {id} errored: {e:?}"),
},
r = rx => match r {
Ok(_) => {
ping_timeout_occurred.store(true, Ordering::Relaxed);
warn!("H2 connection Ping timeout/Error fd: {id}, closing conn");
},
Err(e) => warn!("H2 connection Ping Rx error {e:?}"),
},
};
dropped.store(true, Ordering::Relaxed);
} else {
match c.await {
Ok(_) => debug!("H2 connection finished fd: {id}"),
Err(e) => debug!("H2 connection fd: {id} errored: {e:?}"),
}
}
let _ = closed.send(true);
}
const PING_TIMEOUT: Duration = Duration::from_secs(5);
async fn do_ping_pong(
mut ping_pong: h2::PingPong,
interval: Duration,
tx: oneshot::Sender<()>,
dropped: Arc<AtomicBool>,
id: i32,
) {
// delay before sending the first ping, no need to race with the first request
tokio::time::sleep(interval).await;
loop {
if dropped.load(Ordering::Relaxed) {
break;
}
let ping_fut = ping_pong.ping(h2::Ping::opaque());
debug!("H2 fd: {id} ping sent");
match tokio::time::timeout(PING_TIMEOUT, ping_fut).await {
Err(_) => {
error!("H2 fd: {id} ping timeout");
let _ = tx.send(());
break;
}
Ok(r) => match r {
Ok(_) => {
debug!("H2 fd: {} pong received", id);
tokio::time::sleep(interval).await;
}
Err(e) => {
if dropped.load(Ordering::Relaxed) {
// drive_connection() exits first, no need to error again
break;
}
error!("H2 fd: {id} ping error: {e}");
let _ = tx.send(());
break;
}
},
}
}
}


@ -0,0 +1,18 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HTTP/2 implementation
pub mod client;
pub mod server;


@ -0,0 +1,488 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HTTP/2 server session
use bytes::Bytes;
use futures::Future;
use h2::server;
use h2::server::SendResponse;
use h2::{RecvStream, SendStream};
use http::header::HeaderName;
use http::{header, Response};
use log::{debug, warn};
use pingora_http::{RequestHeader, ResponseHeader};
use crate::protocols::http::body_buffer::FixedBuffer;
use crate::protocols::http::date::get_cached_date;
use crate::protocols::http::v1::client::http_req_header_to_wire;
use crate::protocols::http::HttpTask;
use crate::protocols::Stream;
use crate::{Error, ErrorType, OrErr, Result};
const BODY_BUF_LIMIT: usize = 1024 * 64;
type H2Connection<S> = server::Connection<S, Bytes>;
pub use h2::server::Builder as H2Options;
/// Perform HTTP/2 connection handshake with an established (TLS) connection.
///
/// The optional `options` argument allows adjusting certain HTTP/2 parameters and settings.
/// See [`H2Options`] for more details.
pub async fn handshake(io: Stream, options: Option<H2Options>) -> Result<H2Connection<Stream>> {
let options = options.unwrap_or_default();
let res = options.handshake(io).await;
match res {
Ok(connection) => {
debug!("H2 handshake done.");
Ok(connection)
}
Err(e) => Error::e_because(
ErrorType::HandshakeError,
"while h2 handshaking with client",
e,
),
}
}
use futures::task::Context;
use futures::task::Poll;
use std::pin::Pin;
/// The future to poll for an idle session.
///
/// Calling `.await` in this object will not return until the client decides to close this stream.
pub struct Idle<'a>(&'a mut HttpSession);
impl<'a> Future for Idle<'a> {
type Output = Result<h2::Reason>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(body_writer) = self.0.send_response_body.as_mut() {
body_writer.poll_reset(cx)
} else {
self.0.send_response.poll_reset(cx)
}
.map_err(|e| Error::because(ErrorType::H2Error, "downstream error while idling", e))
}
}
/// HTTP/2 server session
pub struct HttpSession {
request_header: RequestHeader,
request_body_reader: RecvStream,
send_response: SendResponse<Bytes>,
send_response_body: Option<SendStream<Bytes>>,
// Remember what has been written
response_written: Option<Box<ResponseHeader>>,
// Indicates whether an END_STREAM frame has already been sent,
// in order to tell whether one extra empty frame needs to be sent when this response finishes
ended: bool,
// How many request body bytes have been read so far.
body_read: usize,
// How many response body bytes have been sent so far.
body_sent: usize,
// buffered request body for retry logic
retry_buffer: Option<FixedBuffer>,
}
impl HttpSession {
/// Create a new [`HttpSession`] from the HTTP/2 connection.
/// This function returns a new HTTP/2 session when the provided HTTP/2 connection, `conn`,
/// establishes a new HTTP/2 stream to this server.
///
/// Note: in order to handle all **existing** and new HTTP/2 sessions, the server must call
/// this function in a loop until the client decides to close the connection.
///
/// `None` will be returned when the connection is closing so that the loop can exit.
pub async fn from_h2_conn(conn: &mut H2Connection<Stream>) -> Result<Option<Self>> {
// NOTE: conn.accept().await is what drives the entire connection.
let res = conn.accept().await.transpose().or_err(
ErrorType::H2Error,
"while accepting new downstream requests",
)?;
Ok(res.map(|(req, send_response)| {
let (request_header, request_body_reader) = req.into_parts();
HttpSession {
request_header: request_header.into(),
request_body_reader,
send_response,
send_response_body: None,
response_written: None,
ended: false,
body_read: 0,
body_sent: 0,
retry_buffer: None,
}
}))
}
/// The request sent from the client
///
/// Unlike its HTTP/1.x counterpart, this function never panics, as the request was already
/// read when the new HTTP/2 stream was established.
pub fn req_header(&self) -> &RequestHeader {
&self.request_header
}
/// A mutable reference to request sent from the client
///
/// Unlike its HTTP/1.x counterpart, this function never panics, as the request was already
/// read when the new HTTP/2 stream was established.
pub fn req_header_mut(&mut self) -> &mut RequestHeader {
&mut self.request_header
}
/// Read request body bytes. `None` when there is no more body to read.
pub async fn read_body_bytes(&mut self) -> Result<Option<Bytes>> {
// TODO: timeout
let data = self.request_body_reader.data().await.transpose().or_err(
ErrorType::ReadError,
"while reading downstream request body",
)?;
if let Some(data) = data.as_ref() {
self.body_read += data.len();
if let Some(buffer) = self.retry_buffer.as_mut() {
buffer.write_to_buffer(data);
}
let _ = self
.request_body_reader
.flow_control()
.release_capacity(data.len());
}
Ok(data)
}
// The write_* fns don't have timeouts because the actual writing happens on the connection,
// not here.
/// Write the response header to the client.
/// # the `end` flag
/// `end` marks the end of this session.
/// If the `end` flag is set, no more header or body can be sent to the client.
pub fn write_response_header(
&mut self,
mut header: Box<ResponseHeader>,
end: bool,
) -> Result<()> {
if self.ended {
// TODO: error or warn?
return Ok(());
}
// FIXME: we should ignore 1xx header because send_response() can only be called once
// https://github.com/hyperium/h2/issues/167
if let Some(resp) = self.response_written.as_ref() {
if !resp.status.is_informational() {
warn!("Response header is already sent, cannot send again");
return Ok(());
}
}
// no need to add these headers to 1xx responses
if !header.status.is_informational() {
/* update headers */
header.insert_header(header::DATE, get_cached_date())?;
}
// remove other h1 hop headers that cannot be present in H2
// https://httpwg.org/specs/rfc7540.html#n-connection-specific-header-fields
header.remove_header(&header::TRANSFER_ENCODING);
header.remove_header(&header::CONNECTION);
header.remove_header(&header::UPGRADE);
header.remove_header(&HeaderName::from_static("keep-alive"));
header.remove_header(&HeaderName::from_static("proxy-connection"));
let resp = Response::from_parts(header.as_owned_parts(), ());
let body_writer = self.send_response.send_response(resp, end).or_err(
ErrorType::WriteError,
"while writing h2 response to downstream",
)?;
self.response_written = Some(header);
self.send_response_body = Some(body_writer);
self.ended = self.ended || end;
Ok(())
}
/// Write response body to the client. See [Self::write_response_header] for how to use `end`.
pub fn write_body(&mut self, data: Bytes, end: bool) -> Result<()> {
if self.ended {
// NOTE: in h1, we also track to see if content-length matches the data
// We have not tracked that in h2
warn!("Try to write body after end of stream, dropping the extra data");
return Ok(());
}
let Some(writer) = self.send_response_body.as_mut() else {
return Err(Error::explain(
ErrorType::H2Error,
"try to send body before header is sent",
));
};
let data_len = data.len();
writer.reserve_capacity(data_len);
writer.send_data(data, end).or_err(
ErrorType::WriteError,
"while writing h2 response body to downstream",
)?;
self.body_sent += data_len;
self.ended = self.ended || end;
Ok(())
}
/// Similar to [Self::write_response_header], this function takes a reference instead
pub fn write_response_header_ref(&mut self, header: &ResponseHeader, end: bool) -> Result<()> {
self.write_response_header(Box::new(header.clone()), end)
}
// TODO: trailer
/// Mark the session end. If no `end` flag is already set before this call, this call will
/// signal the client. Otherwise this call does nothing.
///
/// Dropping this object without sending `end` will cause an error to the client, making the
/// client treat this session as bad or incomplete.
pub fn finish(&mut self) -> Result<()> {
if self.ended {
// already ended the stream
return Ok(());
}
if let Some(writer) = self.send_response_body.as_mut() {
// use an empty data frame to signal the end
writer.send_data("".into(), true).or_err(
ErrorType::WriteError,
"while writing h2 response body to downstream",
)?;
self.ended = true;
};
// else: the response header is not sent, do nothing now.
// When send_response_body is dropped, an RST_STREAM will be sent
Ok(())
}
pub fn response_duplex_vec(&mut self, tasks: Vec<HttpTask>) -> Result<bool> {
let mut end_stream = false;
for task in tasks.into_iter() {
end_stream = match task {
HttpTask::Header(header, end) => {
self.write_response_header(header, end)
.map_err(|e| e.into_down())?;
end
}
HttpTask::Body(data, end) => match data {
Some(d) => {
if !d.is_empty() {
self.write_body(d, end).map_err(|e| e.into_down())?;
}
end
}
None => end,
},
HttpTask::Trailer(_) => true, // trailer is not supported yet
HttpTask::Done => {
self.finish().map_err(|e| e.into_down())?;
return Ok(true);
}
HttpTask::Failed(e) => {
return Err(e);
}
} || end_stream // safe guard in case `end` in tasks flips from true to false
}
Ok(end_stream)
}
/// Return a string `$METHOD $PATH, Host: $HOST`. Mostly for logging and debugging purposes
pub fn request_summary(&self) -> String {
format!(
"{} {}, Host: {}",
self.request_header.method,
self.request_header.uri,
self.request_header
.headers
.get(header::HOST)
.map(|v| String::from_utf8_lossy(v.as_bytes()))
.unwrap_or_default()
)
}
/// Return the written response header. `None` if it is not written yet.
pub fn response_written(&self) -> Option<&ResponseHeader> {
self.response_written.as_deref()
}
/// Give up the stream abruptly.
///
/// This will send an `INTERNAL_ERROR` stream error to the client
pub fn shutdown(&mut self) {
if !self.ended {
self.send_response.send_reset(h2::Reason::INTERNAL_ERROR);
}
}
// This is a hack for pingora-proxy to create subrequests from h2 server session
// TODO: be able to convert from h2 to h1 subrequest
pub fn pseudo_raw_h1_request_header(&self) -> Bytes {
let buf = http_req_header_to_wire(&self.request_header).unwrap(); // safe, None only when version unknown
buf.freeze()
}
/// Whether there is no more body to read
pub fn is_body_done(&self) -> bool {
self.request_body_reader.is_end_stream()
}
/// Whether there is any body to read.
pub fn is_body_empty(&self) -> bool {
self.body_read == 0
&& (self.is_body_done()
|| self
.request_header
.headers
.get(header::CONTENT_LENGTH)
.map_or(false, |cl| cl.as_bytes() == b"0"))
}
pub fn retry_buffer_truncated(&self) -> bool {
self.retry_buffer
.as_ref()
.map_or_else(|| false, |r| r.is_truncated())
}
pub fn enable_retry_buffering(&mut self) {
if self.retry_buffer.is_none() {
self.retry_buffer = Some(FixedBuffer::new(BODY_BUF_LIMIT))
}
}
pub fn get_retry_buffer(&self) -> Option<Bytes> {
self.retry_buffer.as_ref().and_then(|b| {
if b.is_truncated() {
None
} else {
b.get_buffer()
}
})
}
/// `async fn idle() -> Result<Reason, Error>;`
/// This async fn will be pending forever until the client closes the stream/connection.
/// It is used to watch the client's status, so that the server can cancel its internal
/// tasks when the client waiting on them goes away.
pub fn idle(&mut self) -> Idle {
Idle(self)
}
/// Similar to `read_body_bytes()` but will be pending after Ok(None) is returned,
/// until the client closes the connection
pub async fn read_body_or_idle(&mut self, no_body_expected: bool) -> Result<Option<Bytes>> {
if no_body_expected || self.is_body_done() {
let reason = self.idle().await?;
Error::e_explain(
ErrorType::H2Error,
format!("Client closed H2, reason: {reason}"),
)
} else {
self.read_body_bytes().await
}
}
/// How many response body bytes sent to the client
pub fn body_bytes_sent(&self) -> usize {
self.body_sent
}
}
#[cfg(test)]
mod test {
use super::*;
use http::{Method, Request};
use tokio::io::duplex;
#[tokio::test]
async fn test_server_handshake_accept_request() {
let (client, server) = duplex(65536);
let client_body = "test client body";
let server_body = "test server body";
tokio::spawn(async move {
let (h2, connection) = h2::client::handshake(client).await.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let mut h2 = h2.ready().await.unwrap();
let request = Request::builder()
.method(Method::GET)
.uri("https://www.example.com/")
.body(())
.unwrap();
let (response, mut req_body) = h2.send_request(request, false).unwrap();
req_body.reserve_capacity(client_body.len());
req_body.send_data(client_body.into(), true).unwrap();
let (head, mut body) = response.await.unwrap().into_parts();
assert_eq!(head.status, 200);
let data = body.data().await.unwrap().unwrap();
assert_eq!(data, server_body);
});
let mut connection = handshake(Box::new(server), None).await.unwrap();
while let Some(mut http) = HttpSession::from_h2_conn(&mut connection).await.unwrap() {
tokio::spawn(async move {
let req = http.req_header();
assert_eq!(req.method, Method::GET);
assert_eq!(req.uri, "https://www.example.com/");
http.enable_retry_buffering();
assert!(!http.is_body_empty());
assert!(!http.is_body_done());
let body = http.read_body_or_idle(false).await.unwrap().unwrap();
assert_eq!(body, client_body);
assert!(http.is_body_done());
let retry_body = http.get_retry_buffer().unwrap();
assert_eq!(retry_body, client_body);
// test idling before response header is sent
tokio::select! {
_ = http.idle() => {panic!("downstream should be idling")},
_= tokio::time::sleep(tokio::time::Duration::from_secs(1)) => {}
}
let response_header = Box::new(ResponseHeader::build(200, None).unwrap());
http.write_response_header(response_header, false).unwrap();
// test idling after response header is sent
tokio::select! {
_ = http.read_body_or_idle(false) => {panic!("downstream should be idling")},
_= tokio::time::sleep(tokio::time::Duration::from_secs(1)) => {}
}
// end: false here to verify finish() closes the stream nicely
http.write_body(server_body.into(), false).unwrap();
http.finish().unwrap();
});
}
}
}


@ -0,0 +1,297 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Extensions to the regular TCP APIs
#![allow(non_camel_case_types)]
use libc::socklen_t;
#[cfg(target_os = "linux")]
use libc::{c_int, c_void};
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use std::io::{self, ErrorKind};
use std::mem;
use std::net::SocketAddr;
use std::os::unix::io::{AsRawFd, RawFd};
use std::time::Duration;
use tokio::net::{TcpSocket, TcpStream, UnixStream};
/// A copy of the kernel's `tcp_info` struct, as returned by the `TCP_INFO` socket option
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct TCP_INFO {
tcpi_state: u8,
tcpi_ca_state: u8,
tcpi_retransmits: u8,
tcpi_probes: u8,
tcpi_backoff: u8,
tcpi_options: u8,
tcpi_snd_wscale_4_rcv_wscale_4: u8,
tcpi_delivery_rate_app_limited: u8,
tcpi_rto: u32,
tcpi_ato: u32,
tcpi_snd_mss: u32,
tcpi_rcv_mss: u32,
tcpi_unacked: u32,
tcpi_sacked: u32,
tcpi_lost: u32,
tcpi_retrans: u32,
tcpi_fackets: u32,
tcpi_last_data_sent: u32,
tcpi_last_ack_sent: u32,
tcpi_last_data_recv: u32,
tcpi_last_ack_recv: u32,
tcpi_pmtu: u32,
tcpi_rcv_ssthresh: u32,
pub tcpi_rtt: u32,
tcpi_rttvar: u32,
/* uncomment these fields if needed
tcpi_snd_ssthresh: u32,
tcpi_snd_cwnd: u32,
tcpi_advmss: u32,
tcpi_reordering: u32,
tcpi_rcv_rtt: u32,
tcpi_rcv_space: u32,
tcpi_total_retrans: u32,
tcpi_pacing_rate: u64,
tcpi_max_pacing_rate: u64,
tcpi_bytes_acked: u64,
tcpi_bytes_received: u64,
tcpi_segs_out: u32,
tcpi_segs_in: u32,
tcpi_notsent_bytes: u32,
tcpi_min_rtt: u32,
tcpi_data_segs_in: u32,
tcpi_data_segs_out: u32,
tcpi_delivery_rate: u64,
*/
/* and more, see include/linux/tcp.h */
}
impl TCP_INFO {
/// Create a new zeroed out [`TCP_INFO`]
pub unsafe fn new() -> Self {
mem::zeroed()
}
/// Return the size of [`TCP_INFO`]
pub fn len() -> socklen_t {
mem::size_of::<Self>() as socklen_t
}
}
#[cfg(target_os = "linux")]
fn set_opt<T: Copy>(sock: c_int, opt: c_int, val: c_int, payload: T) -> io::Result<()> {
unsafe {
let payload = &payload as *const T as *const c_void;
cvt_linux_error(libc::setsockopt(
sock,
opt,
val,
payload as *const _,
mem::size_of::<T>() as socklen_t,
))?;
Ok(())
}
}
#[cfg(target_os = "linux")]
fn get_opt<T>(
sock: c_int,
opt: c_int,
val: c_int,
payload: &mut T,
size: &mut socklen_t,
) -> io::Result<()> {
unsafe {
let payload = payload as *mut T as *mut c_void;
cvt_linux_error(libc::getsockopt(sock, opt, val, payload as *mut _, size))?;
Ok(())
}
}
#[cfg(target_os = "linux")]
fn cvt_linux_error(t: i32) -> io::Result<i32> {
if t == -1 {
Err(io::Error::last_os_error())
} else {
Ok(t)
}
}
#[cfg(target_os = "linux")]
fn ip_bind_addr_no_port(fd: RawFd, val: bool) -> io::Result<()> {
const IP_BIND_ADDRESS_NO_PORT: i32 = 24;
set_opt(fd, libc::IPPROTO_IP, IP_BIND_ADDRESS_NO_PORT, val as c_int)
}
#[cfg(not(target_os = "linux"))]
fn ip_bind_addr_no_port(_fd: RawFd, _val: bool) -> io::Result<()> {
Ok(())
}
#[cfg(target_os = "linux")]
fn set_so_keepalive(fd: RawFd, val: bool) -> io::Result<()> {
set_opt(fd, libc::SOL_SOCKET, libc::SO_KEEPALIVE, val as c_int)
}
#[cfg(target_os = "linux")]
fn set_so_keepalive_idle(fd: RawFd, val: Duration) -> io::Result<()> {
set_opt(
fd,
libc::IPPROTO_TCP,
libc::TCP_KEEPIDLE,
val.as_secs() as c_int, // only the seconds part of val is used
)
}
#[cfg(target_os = "linux")]
fn set_so_keepalive_interval(fd: RawFd, val: Duration) -> io::Result<()> {
set_opt(
fd,
libc::IPPROTO_TCP,
libc::TCP_KEEPINTVL,
val.as_secs() as c_int, // only the seconds part of val is used
)
}
#[cfg(target_os = "linux")]
fn set_so_keepalive_count(fd: RawFd, val: usize) -> io::Result<()> {
set_opt(fd, libc::IPPROTO_TCP, libc::TCP_KEEPCNT, val as c_int)
}
#[cfg(target_os = "linux")]
fn set_keepalive(fd: RawFd, ka: &TcpKeepalive) -> io::Result<()> {
set_so_keepalive(fd, true)?;
set_so_keepalive_idle(fd, ka.idle)?;
set_so_keepalive_interval(fd, ka.interval)?;
set_so_keepalive_count(fd, ka.count)
}
#[cfg(not(target_os = "linux"))]
fn set_keepalive(_fd: RawFd, _ka: &TcpKeepalive) -> io::Result<()> {
Ok(())
}
/// Get the kernel TCP_INFO for the given FD.
#[cfg(target_os = "linux")]
pub fn get_tcp_info(fd: RawFd) -> io::Result<TCP_INFO> {
let mut tcp_info = unsafe { TCP_INFO::new() };
let mut data_len: socklen_t = TCP_INFO::len();
get_opt(
fd,
libc::IPPROTO_TCP,
libc::TCP_INFO,
&mut tcp_info,
&mut data_len,
)?;
if data_len != TCP_INFO::len() {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"TCP_INFO struct size mismatch",
));
}
Ok(tcp_info)
}
#[cfg(not(target_os = "linux"))]
pub fn get_tcp_info(_fd: RawFd) -> io::Result<TCP_INFO> {
Ok(unsafe { TCP_INFO::new() })
}
/*
* this extension is needed until the following are addressed
* https://github.com/tokio-rs/tokio/issues/1543
* https://github.com/tokio-rs/mio/issues/1257
* https://github.com/tokio-rs/mio/issues/1211
*/
/// connect() to the given address while optionally binding to the specified source address
///
/// `IP_BIND_ADDRESS_NO_PORT` is used.
pub async fn connect(addr: &SocketAddr, bind_to: Option<&SocketAddr>) -> Result<TcpStream> {
let socket = if addr.is_ipv4() {
TcpSocket::new_v4()
} else {
TcpSocket::new_v6()
}
.or_err(SocketError, "failed to create socket")?;
if cfg!(target_os = "linux") {
ip_bind_addr_no_port(socket.as_raw_fd(), true)
.or_err(SocketError, "failed to set socket opts")?;
if let Some(baddr) = bind_to {
socket
.bind(*baddr)
.or_err_with(BindError, || format!("failed to bind to socket {}", *baddr))?;
};
}
// TODO: add support for bind on other platforms
socket
.connect(*addr)
.await
.map_err(|e| wrap_os_connect_error(e, format!("failed to connect to {}", *addr)))
}
/// connect() to the given Unix domain socket
pub async fn connect_uds(path: &std::path::Path) -> Result<UnixStream> {
UnixStream::connect(path)
.await
.map_err(|e| wrap_os_connect_error(e, format!("failed to connect to {}", path.display())))
}
fn wrap_os_connect_error(e: std::io::Error, context: String) -> Box<Error> {
match e.kind() {
ErrorKind::ConnectionRefused => Error::because(ConnectRefused, context, e),
ErrorKind::TimedOut => Error::because(ConnectTimedout, context, e),
ErrorKind::PermissionDenied | ErrorKind::AddrInUse | ErrorKind::AddrNotAvailable => {
Error::because(InternalError, context, e)
}
_ => match e.raw_os_error() {
Some(code) => match code {
libc::ENETUNREACH | libc::EHOSTUNREACH => {
Error::because(ConnectNoRoute, context, e)
}
_ => Error::because(ConnectError, context, e),
},
None => Error::because(ConnectError, context, e),
},
}
}
/// The configuration for TCP keepalive
#[derive(Clone, Debug)]
pub struct TcpKeepalive {
/// The time a connection needs to be idle before TCP begins sending out keep-alive probes.
pub idle: Duration,
/// The time between individual TCP keep-alive probes.
pub interval: Duration,
/// The maximum number of TCP keep-alive probes to send before giving up and killing the connection.
pub count: usize,
}
impl std::fmt::Display for TcpKeepalive {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}/{:?}/{}", self.idle, self.interval, self.count)
}
}
/// Apply the given TCP keepalive settings to the given connection
pub fn set_tcp_keepalive(stream: &TcpStream, ka: &TcpKeepalive) -> Result<()> {
let fd = stream.as_raw_fd();
// TODO: check localhost or if keepalive is already set
set_keepalive(fd, ka).or_err(ConnectError, "failed to set keepalive")
}
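To show how the `TcpKeepalive` settings above render (for example, in logs), here is a minimal, self-contained sketch that copies the struct and its `Display` impl; the real type lives in this module, and this copy exists only for illustration:

```rust
use std::time::Duration;

// Illustrative stand-in for the TcpKeepalive type defined above.
#[derive(Clone, Debug)]
struct TcpKeepalive {
    idle: Duration,
    interval: Duration,
    count: usize,
}

impl std::fmt::Display for TcpKeepalive {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same "idle/interval/count" layout as the impl above
        write!(f, "{:?}/{:?}/{}", self.idle, self.interval, self.count)
    }
}

fn keepalive_summary() -> String {
    let ka = TcpKeepalive {
        idle: Duration::from_secs(60),
        interval: Duration::from_secs(5),
        count: 5,
    };
    ka.to_string()
}

fn main() {
    // Duration's Debug form prints whole seconds as "60s"
    assert_eq!(keepalive_summary(), "60s/5s/5");
}
```

Note that `idle` and `interval` are truncated to whole seconds when applied, since `TCP_KEEPIDLE` and `TCP_KEEPINTVL` only accept seconds.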


@ -0,0 +1,59 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Listeners
use std::io;
use std::os::unix::io::AsRawFd;
use tokio::net::{TcpListener, UnixListener};
use crate::protocols::l4::stream::Stream;
/// The type for generic listener for both TCP and Unix domain socket
#[derive(Debug)]
pub enum Listener {
Tcp(TcpListener),
Unix(UnixListener),
}
impl From<TcpListener> for Listener {
fn from(s: TcpListener) -> Self {
Self::Tcp(s)
}
}
impl From<UnixListener> for Listener {
fn from(s: UnixListener) -> Self {
Self::Unix(s)
}
}
impl AsRawFd for Listener {
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
match &self {
Self::Tcp(l) => l.as_raw_fd(),
Self::Unix(l) => l.as_raw_fd(),
}
}
}
impl Listener {
/// Accept a connection from the listening endpoint
pub async fn accept(&self) -> io::Result<Stream> {
match &self {
Self::Tcp(l) => l.accept().await.map(|(stream, _)| stream.into()),
Self::Unix(l) => l.accept().await.map(|(stream, _)| stream.into()),
}
}
}


@ -0,0 +1,20 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Transport layer protocol implementation
pub mod ext;
pub mod listener;
pub mod socket;
pub mod stream;


@ -0,0 +1,185 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Generic socket type
use crate::{Error, OrErr};
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::net::SocketAddr as StdSockAddr;
use std::os::unix::net::SocketAddr as StdUnixSockAddr;
/// [`SocketAddr`] is a storage type that contains either an Internet (IP address)
/// socket address or a Unix domain socket address.
#[derive(Debug, Clone)]
pub enum SocketAddr {
Inet(StdSockAddr),
Unix(StdUnixSockAddr),
}
impl SocketAddr {
/// Get a reference to the IP socket if it is one
pub fn as_inet(&self) -> Option<&StdSockAddr> {
if let SocketAddr::Inet(addr) = self {
Some(addr)
} else {
None
}
}
/// Get a reference to the Unix domain socket if it is one
pub fn as_unix(&self) -> Option<&StdUnixSockAddr> {
if let SocketAddr::Unix(addr) = self {
Some(addr)
} else {
None
}
}
/// Set the port if the address is an IP socket.
pub fn set_port(&mut self, port: u16) {
if let SocketAddr::Inet(addr) = self {
addr.set_port(port)
}
}
}
impl std::fmt::Display for SocketAddr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
SocketAddr::Inet(addr) => write!(f, "{addr}"),
SocketAddr::Unix(addr) => {
if let Some(path) = addr.as_pathname() {
write!(f, "{}", path.display())
} else {
write!(f, "{addr:?}")
}
}
}
}
}
impl Hash for SocketAddr {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
Self::Inet(sockaddr) => sockaddr.hash(state),
Self::Unix(sockaddr) => {
if let Some(path) = sockaddr.as_pathname() {
// use the underlying path as the hash
path.hash(state);
} else {
// unnamed or abstract UDS
// abstract UDS name not yet exposed by std API
// panic for now, we can decide on the right way to hash them later
panic!("Unnamed and abstract UDS types not yet supported for hashing")
}
}
}
}
}
impl PartialEq for SocketAddr {
fn eq(&self, other: &Self) -> bool {
match self {
Self::Inet(addr) => Some(addr) == other.as_inet(),
Self::Unix(addr) => {
let path = addr.as_pathname();
// can only compare UDS with path, assume false on all unnamed UDS
path.is_some() && path == other.as_unix().and_then(|addr| addr.as_pathname())
}
}
}
}
impl PartialOrd for SocketAddr {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for SocketAddr {
fn cmp(&self, other: &Self) -> Ordering {
match self {
Self::Inet(addr) => {
if let Some(o) = other.as_inet() {
addr.cmp(o)
} else {
// always order Inet < Unix (the "smallest" variant sorts first)
Ordering::Less
}
}
Self::Unix(addr) => {
if let Some(o) = other.as_unix() {
// NOTE: unnamed UDS are considered equal
addr.as_pathname().cmp(&o.as_pathname())
} else {
// always order Inet < Unix (the "smallest" variant sorts first)
Ordering::Greater
}
}
}
}
}
impl Eq for SocketAddr {}
impl std::str::FromStr for SocketAddr {
type Err = Box<Error>;
// This is very basic parsing logic; it may treat an invalid IP:PORT string as a UDS path
// TODO: require UDS to have some prefix
fn from_str(s: &str) -> Result<Self, Self::Err> {
match StdSockAddr::from_str(s) {
Ok(addr) => Ok(SocketAddr::Inet(addr)),
Err(_) => {
let uds_socket = StdUnixSockAddr::from_pathname(s)
.or_err(crate::BindError, "invalid UDS path")?;
Ok(SocketAddr::Unix(uds_socket))
}
}
}
}
impl std::net::ToSocketAddrs for SocketAddr {
type Iter = std::iter::Once<StdSockAddr>;
// Error if UDS addr
fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
if let Some(inet) = self.as_inet() {
Ok(std::iter::once(*inet))
} else {
Err(std::io::Error::new(
std::io::ErrorKind::Other,
"UDS socket cannot be used as inet socket",
))
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn parse_ip() {
let ip: SocketAddr = "127.0.0.1:80".parse().unwrap();
assert!(ip.as_inet().is_some());
}
#[test]
fn parse_uds() {
let uds: SocketAddr = "/tmp/my.sock".parse().unwrap();
assert!(uds.as_unix().is_some());
}
}
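The parse-with-fallback behavior (and the caveat in the TODO above) can be sketched with std types only; `parse_addr` and its tag strings are illustrative names, not part of this module:

```rust
use std::net::SocketAddr as StdSockAddr;
use std::os::unix::net::SocketAddr as StdUnixSockAddr;
use std::str::FromStr;

// Mirror of the FromStr impl above: anything that fails to parse as
// IP:PORT is treated as a UDS path.
fn parse_addr(s: &str) -> Result<String, String> {
    match StdSockAddr::from_str(s) {
        Ok(addr) => Ok(format!("inet:{addr}")),
        Err(_) => StdUnixSockAddr::from_pathname(s)
            .map(|_| format!("uds:{s}"))
            .map_err(|e| e.to_string()),
    }
}

fn main() {
    assert_eq!(parse_addr("127.0.0.1:80").unwrap(), "inet:127.0.0.1:80");
    assert_eq!(parse_addr("/tmp/my.sock").unwrap(), "uds:/tmp/my.sock");
    // the caveat: a malformed inet address silently falls through to UDS
    assert_eq!(
        parse_addr("127.0.0.1:not_a_port").unwrap(),
        "uds:127.0.0.1:not_a_port"
    );
}
```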


@ -0,0 +1,378 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Transport layer connection
use async_trait::async_trait;
use futures::FutureExt;
use log::{debug, error};
use pingora_error::{ErrorType::*, OrErr, Result};
use std::os::unix::io::AsRawFd;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::SystemTime;
use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
use tokio::net::{TcpStream, UnixStream};
use crate::protocols::raw_connect::ProxyDigest;
use crate::protocols::{GetProxyDigest, GetTimingDigest, Shutdown, Ssl, TimingDigest, UniqueID};
use crate::upstreams::peer::Tracer;
#[derive(Debug)]
enum RawStream {
Tcp(TcpStream),
Unix(UnixStream),
}
impl AsyncRead for RawStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
// Safety: Basic enum pin projection
unsafe {
match &mut Pin::get_unchecked_mut(self) {
RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
}
}
}
}
impl AsyncWrite for RawStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
// Safety: Basic enum pin projection
unsafe {
match &mut Pin::get_unchecked_mut(self) {
RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
}
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
// Safety: Basic enum pin projection
unsafe {
match &mut Pin::get_unchecked_mut(self) {
RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
}
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
// Safety: Basic enum pin projection
unsafe {
match &mut Pin::get_unchecked_mut(self) {
RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
}
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
// Safety: Basic enum pin projection
unsafe {
match &mut Pin::get_unchecked_mut(self) {
RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
}
}
}
fn is_write_vectored(&self) -> bool {
match self {
RawStream::Tcp(s) => s.is_write_vectored(),
RawStream::Unix(s) => s.is_write_vectored(),
}
}
}
// Large read buffering helps reduce syscalls with little trade-off.
// The SSL layer always does "small" reads of at most 16KB (the TLS record size), so an L4 read buffer helps a lot.
const BUF_READ_SIZE: usize = 64 * 1024;
// Small write buffer to match the MSS. A write buffer that is too large delays real-time communication.
// This buffering effectively implements something similar to Nagle's algorithm.
// The benefit is that user space can control when to flush, which Nagle's algorithm doesn't allow.
// User-space buffering also reduces both syscalls and small packets.
const BUF_WRITE_SIZE: usize = 1460;
// NOTE: with write buffering, users need to call flush() to make sure the data is actually
// sent. Otherwise data could be stuck in the buffer forever or get lost when the stream is closed.
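The flush requirement described in the note above can be seen with std's `BufWriter` as a synchronous analogue of the buffered stream (a sketch only; the async `Stream` in this module behaves similarly through its `BufStream`):

```rust
use std::io::{BufWriter, Write};

// Data written through a buffered writer is not visible to the underlying
// sink until flush() is called (or the buffer fills up).
fn buffered_write_demo() -> (usize, usize) {
    let mut w = BufWriter::with_capacity(1460, Vec::new());
    w.write_all(b"hello").unwrap();
    let before = w.get_ref().len(); // still 0: the bytes sit in the buffer
    w.flush().unwrap();
    let after = w.get_ref().len(); // 5 after flushing
    (before, after)
}

fn main() {
    assert_eq!(buffered_write_demo(), (0, 5));
}
```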
/// A concrete type for transport layer connection + extra fields for logging
#[derive(Debug)]
pub struct Stream {
stream: BufStream<RawStream>,
buffer_write: bool,
proxy_digest: Option<Arc<ProxyDigest>>,
/// When this connection is established
pub established_ts: SystemTime,
/// The distributed tracing object for this stream
pub tracer: Option<Tracer>,
}
impl Stream {
/// set TCP nodelay for this connection if `self` is TCP
pub fn set_nodelay(&mut self) -> Result<()> {
if let RawStream::Tcp(s) = &self.stream.get_ref() {
s.set_nodelay(true)
.or_err(ConnectError, "failed to set_nodelay")?;
}
Ok(())
}
}
impl From<TcpStream> for Stream {
fn from(s: TcpStream) -> Self {
Stream {
stream: BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, RawStream::Tcp(s)),
buffer_write: true,
established_ts: SystemTime::now(),
proxy_digest: None,
tracer: None,
}
}
}
impl From<UnixStream> for Stream {
fn from(s: UnixStream) -> Self {
Stream {
stream: BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, RawStream::Unix(s)),
buffer_write: true,
established_ts: SystemTime::now(),
proxy_digest: None,
tracer: None,
}
}
}
impl UniqueID for Stream {
fn id(&self) -> i32 {
match &self.stream.get_ref() {
RawStream::Tcp(s) => s.as_raw_fd(),
RawStream::Unix(s) => s.as_raw_fd(),
}
}
}
impl Ssl for Stream {}
#[async_trait]
impl Shutdown for Stream {
async fn shutdown(&mut self) {
AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
debug!("Failed to shutdown connection: {:?}", e);
});
}
}
impl GetTimingDigest for Stream {
fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
let mut digest = Vec::with_capacity(2); // expect to have both L4 stream and TLS layer
digest.push(Some(TimingDigest {
established_ts: self.established_ts,
}));
digest
}
}
impl GetProxyDigest for Stream {
fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
self.proxy_digest.clone()
}
fn set_proxy_digest(&mut self, digest: ProxyDigest) {
self.proxy_digest = Some(Arc::new(digest));
}
}
impl Drop for Stream {
fn drop(&mut self) {
if let Some(t) = self.tracer.as_ref() {
t.0.on_disconnected();
}
/* use nodelay/local_addr function to detect socket status */
let ret = match &self.stream.get_ref() {
RawStream::Tcp(s) => s.nodelay().err(),
RawStream::Unix(s) => s.local_addr().err(),
};
if let Some(e) = ret {
match e.kind() {
tokio::io::ErrorKind::Other => {
if let Some(ecode) = e.raw_os_error() {
if ecode == 9 {
// Or we could panic here
error!("Crit: socket {:?} is being double closed", self.stream);
}
}
}
_ => {
debug!("Socket is already broken {:?}", e);
}
}
} else {
// try to flush the write buffer. We use now_or_never() because
// 1. Drop cannot be async
// 2. write should usually be ready, unless the buf is full.
let _ = self.flush().now_or_never();
}
debug!("Dropping socket {:?}", self.stream);
}
}
impl AsyncRead for Stream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_read(cx, buf)
}
}
impl AsyncWrite for Stream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if self.buffer_write {
Pin::new(&mut self.stream).poll_write(cx, buf)
} else {
Pin::new(&mut self.stream.get_mut()).poll_write(cx, buf)
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_shutdown(cx)
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
if self.buffer_write {
Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
} else {
Pin::new(&mut self.stream.get_mut()).poll_write_vectored(cx, bufs)
}
}
fn is_write_vectored(&self) -> bool {
if self.buffer_write {
self.stream.is_write_vectored() // it is true
} else {
self.stream.get_ref().is_write_vectored()
}
}
}
pub mod async_write_vec {
use bytes::Buf;
use futures::ready;
use std::future::Future;
use std::io::IoSlice;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io;
use tokio::io::AsyncWrite;
/*
the missing write_buf https://github.com/tokio-rs/tokio/pull/3156#issuecomment-738207409
https://github.com/tokio-rs/tokio/issues/2610
In general vectored write is lost when accessing the trait object: Box<S: AsyncWrite>
*/
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WriteVec<'a, W, B> {
writer: &'a mut W,
buf: &'a mut B,
}
pub trait AsyncWriteVec {
fn poll_write_vec<B: Buf>(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut B,
) -> Poll<io::Result<usize>>;
fn write_vec<'a, B>(&'a mut self, src: &'a mut B) -> WriteVec<'a, Self, B>
where
Self: Sized,
B: Buf,
{
WriteVec {
writer: self,
buf: src,
}
}
}
impl<W, B> Future for WriteVec<'_, W, B>
where
W: AsyncWriteVec + Unpin,
B: Buf,
{
type Output = io::Result<usize>;
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<usize>> {
let me = &mut *self;
Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf)
}
}
/* from https://github.com/tokio-rs/tokio/blob/master/tokio-util/src/lib.rs#L177 */
impl<T> AsyncWriteVec for T
where
T: AsyncWrite,
{
fn poll_write_vec<B: Buf>(
self: Pin<&mut Self>,
ctx: &mut Context,
buf: &mut B,
) -> Poll<io::Result<usize>> {
const MAX_BUFS: usize = 64;
if !buf.has_remaining() {
return Poll::Ready(Ok(0));
}
let n = if self.is_write_vectored() {
let mut slices = [IoSlice::new(&[]); MAX_BUFS];
let cnt = buf.chunks_vectored(&mut slices);
ready!(self.poll_write_vectored(ctx, &slices[..cnt]))?
} else {
ready!(self.poll_write(ctx, buf.chunk()))?
};
buf.advance(n);
Poll::Ready(Ok(n))
}
}
}
pub use async_write_vec::AsyncWriteVec;
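The vectored-write path that `poll_write_vec` takes can be sketched with std's blocking `Write::write_vectored`, which `Vec<u8>` implements; `vectored_demo` is an illustrative name, not part of this module:

```rust
use std::io::{IoSlice, Write};

// Several chunks are handed to the sink in one call, mirroring how
// poll_write_vec gathers up to 64 chunks via chunks_vectored().
fn vectored_demo() -> Vec<u8> {
    let mut sink: Vec<u8> = Vec::new();
    let bufs = [IoSlice::new(b"hello "), IoSlice::new(b"world")];
    let n = sink.write_vectored(&bufs).expect("write to Vec cannot fail");
    assert_eq!(n, 11); // both slices written in one call
    sink
}

fn main() {
    assert_eq!(vectored_demo(), b"hello world".to_vec());
}
```

For real sockets, `write_vectored` may accept fewer bytes than offered, which is why `poll_write_vec` advances the `Buf` by the returned count rather than assuming a complete write.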


@ -0,0 +1,253 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Abstractions and implementations for protocols including TCP, TLS and HTTP
mod digest;
pub mod http;
pub mod l4;
pub mod raw_connect;
pub mod ssl;
pub use digest::{Digest, GetProxyDigest, GetTimingDigest, ProtoDigest, TimingDigest};
pub use ssl::ALPN;
use async_trait::async_trait;
use std::fmt::Debug;
use std::sync::Arc;
/// Define how a protocol should shutdown its connection.
#[async_trait]
pub trait Shutdown {
async fn shutdown(&mut self) -> ();
}
/// Define how a given session/connection identifies itself.
pub trait UniqueID {
/// The ID returned should be unique among all existing connections of the same type.
/// But an ID can be recycled after a connection is shut down.
fn id(&self) -> i32;
}
/// Interface to get TLS info
pub trait Ssl {
/// Return the TLS info if the connection is over TLS
fn get_ssl(&self) -> Option<&crate::tls::ssl::SslRef> {
None
}
/// Return the [`ssl::SslDigest`] for logging
fn get_ssl_digest(&self) -> Option<Arc<ssl::SslDigest>> {
None
}
/// Return selected ALPN if any
fn selected_alpn_proto(&self) -> Option<ALPN> {
let ssl = self.get_ssl()?;
ALPN::from_wire_selected(ssl.selected_alpn_protocol()?)
}
}
use std::any::Any;
use tokio::io::{AsyncRead, AsyncWrite};
/// The abstraction of transport layer IO
pub trait IO:
AsyncRead
+ AsyncWrite
+ Shutdown
+ UniqueID
+ Ssl
+ GetTimingDigest
+ GetProxyDigest
+ Unpin
+ Debug
+ Send
+ Sync
{
/// Helper to cast to a reference of the concrete type
fn as_any(&self) -> &dyn Any;
/// Helper to cast back to the boxed concrete type
fn into_any(self: Box<Self>) -> Box<dyn Any>;
}
impl<
T: AsyncRead
+ AsyncWrite
+ Shutdown
+ UniqueID
+ Ssl
+ GetTimingDigest
+ GetProxyDigest
+ Unpin
+ Debug
+ Send
+ Sync,
> IO for T
where
T: 'static,
{
fn as_any(&self) -> &dyn Any {
self
}
fn into_any(self: Box<Self>) -> Box<dyn Any> {
self
}
}
/// The type of any established transport layer connection
pub type Stream = Box<dyn IO>;
// Implement IO trait for 3rd party types, mostly for testing
mod ext_io_impl {
use super::*;
use tokio_test::io::Mock;
#[async_trait]
impl Shutdown for Mock {
async fn shutdown(&mut self) -> () {}
}
impl UniqueID for Mock {
fn id(&self) -> i32 {
0
}
}
impl Ssl for Mock {}
impl GetTimingDigest for Mock {
fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
vec![]
}
}
impl GetProxyDigest for Mock {
fn get_proxy_digest(&self) -> Option<Arc<raw_connect::ProxyDigest>> {
None
}
}
use std::io::Cursor;
#[async_trait]
impl<T: Send> Shutdown for Cursor<T> {
async fn shutdown(&mut self) -> () {}
}
impl<T> UniqueID for Cursor<T> {
fn id(&self) -> i32 {
0
}
}
impl<T> Ssl for Cursor<T> {}
impl<T> GetTimingDigest for Cursor<T> {
fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
vec![]
}
}
impl<T> GetProxyDigest for Cursor<T> {
fn get_proxy_digest(&self) -> Option<Arc<raw_connect::ProxyDigest>> {
None
}
}
use tokio::io::DuplexStream;
#[async_trait]
impl Shutdown for DuplexStream {
async fn shutdown(&mut self) -> () {}
}
impl UniqueID for DuplexStream {
fn id(&self) -> i32 {
0
}
}
impl Ssl for DuplexStream {}
impl GetTimingDigest for DuplexStream {
fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
vec![]
}
}
impl GetProxyDigest for DuplexStream {
fn get_proxy_digest(&self) -> Option<Arc<raw_connect::ProxyDigest>> {
None
}
}
}
pub(crate) trait ConnFdReusable {
fn check_fd_match<V: AsRawFd>(&self, fd: V) -> bool;
}
use l4::socket::SocketAddr;
use log::{debug, error};
use nix::sys::socket::{getpeername, SockaddrStorage, UnixAddr};
use std::{net::SocketAddr as InetSocketAddr, os::unix::prelude::AsRawFd, path::Path};
impl ConnFdReusable for SocketAddr {
fn check_fd_match<V: AsRawFd>(&self, fd: V) -> bool {
match self {
SocketAddr::Inet(addr) => addr.check_fd_match(fd),
SocketAddr::Unix(addr) => addr
.as_pathname()
.expect("non-pathname unix sockets not supported as peer")
.check_fd_match(fd),
}
}
}
impl ConnFdReusable for Path {
fn check_fd_match<V: AsRawFd>(&self, fd: V) -> bool {
let fd = fd.as_raw_fd();
match getpeername::<UnixAddr>(fd) {
Ok(peer) => match UnixAddr::new(self) {
Ok(addr) => {
if addr == peer {
debug!("Unix FD to: {peer:?} is reusable");
true
} else {
error!("Crit: unix FD mismatch: fd: {fd:?}, peer: {peer:?}, addr: {addr}",);
false
}
}
Err(e) => {
error!("Bad addr: {self:?}, error: {e:?}");
false
}
},
Err(e) => {
error!("Idle unix connection is broken: {e:?}");
false
}
}
}
}
impl ConnFdReusable for InetSocketAddr {
fn check_fd_match<V: AsRawFd>(&self, fd: V) -> bool {
let fd = fd.as_raw_fd();
match getpeername::<SockaddrStorage>(fd) {
Ok(peer) => {
let addr = SockaddrStorage::from(*self);
if addr == peer {
debug!("Inet FD to: {peer:?} is reusable");
true
} else {
error!("Crit: FD mismatch: fd: {fd:?}, addr: {addr:?}, peer: {peer:?}",);
false
}
}
Err(e) => {
debug!("Idle connection is broken: {e:?}");
false
}
}
}
}
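The `as_any()`/`into_any()` pattern on the `IO` trait above lets callers recover the concrete type behind a `Box<dyn IO>`. A reduced, self-contained sketch (trait and type names here are illustrative):

```rust
use std::any::Any;

// Minimal version of the IO trait's downcast helpers.
trait Io: Any {
    fn as_any(&self) -> &dyn Any;
    fn into_any(self: Box<Self>) -> Box<dyn Any>;
}

struct MockStream(u32);

impl Io for MockStream {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}

fn roundtrip() -> u32 {
    let stream: Box<dyn Io> = Box::new(MockStream(7));
    // borrow and downcast by reference
    assert!(stream.as_any().downcast_ref::<MockStream>().is_some());
    // consume and downcast back to the concrete type
    let concrete = stream.into_any().downcast::<MockStream>().unwrap();
    concrete.0
}

fn main() {
    assert_eq!(roundtrip(), 7);
}
```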


@ -0,0 +1,271 @@
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! CONNECT protocol over HTTP 1.1 via a raw Unix domain socket
//!
//! This mod implements the most rudimentary CONNECT client over a raw stream.
//! The idea is to yield the raw stream once the CONNECT handshake is complete
//! so that the encapsulated protocol can use the stream directly.
//! This idea only works for CONNECT over HTTP 1.1 and localhost (or where the server is close by).
use super::http::v1::client::HttpSession;
use super::http::v1::common::*;
use super::Stream;
use bytes::{BufMut, BytesMut};
use http::request::Parts as ReqHeader;
use http::Version;
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use pingora_http::ResponseHeader;
use tokio::io::AsyncWriteExt;
/// Try to establish a CONNECT proxy via the given `stream`.
///
/// `request_header` should include the necessary request headers for the CONNECT protocol.
///
/// When successful, a [`Stream`] will be returned which is the established CONNECT proxy connection.
pub async fn connect(stream: Stream, request_header: &ReqHeader) -> Result<(Stream, ProxyDigest)> {
let mut http = HttpSession::new(stream);
// We write to the stream directly because HttpSession doesn't write request headers in authority form
let to_wire = http_req_header_to_wire_auth_form(request_header);
http.underlying_stream
.write_all(to_wire.as_ref())
.await
.or_err(WriteError, "while writing request headers")?;
http.underlying_stream
.flush()
.await
.or_err(WriteError, "while flushing request headers")?;
// TODO: set http.read_timeout
let resp_header = http.read_resp_header_parts().await?;
Ok((
http.underlying_stream,
validate_connect_response(resp_header)?,
))
}
/// Generate the CONNECT header for the given destination
pub fn generate_connect_header<'a, H, S>(
host: &str,
port: u16,
headers: H,
) -> Result<Box<ReqHeader>>
where
S: AsRef<[u8]>,
H: Iterator<Item = (S, &'a Vec<u8>)>,
{
// TODO: validate that host doesn't contain a port
// TODO: support adding ad-hoc headers
let authority = format!("{host}:{port}");
let req = http::request::Builder::new()
.version(http::Version::HTTP_11)
.method(http::method::Method::CONNECT)
.uri(format!("https://{authority}/")) // scheme doesn't matter
.header(http::header::HOST, &authority);
let (mut req, _) = match req.body(()) {
Ok(r) => r.into_parts(),
Err(e) => {
return Err(e).or_err(InvalidHTTPHeader, "Invalid CONNECT request");
}
};
for (k, v) in headers {
let header_name = http::header::HeaderName::from_bytes(k.as_ref())
.or_err(InvalidHTTPHeader, "Invalid CONNECT request")?;
let header_value = http::header::HeaderValue::from_bytes(v.as_slice())
.or_err(InvalidHTTPHeader, "Invalid CONNECT request")?;
req.headers.insert(header_name, header_value);
}
Ok(Box::new(req))
}
/// The information about the CONNECT proxy.
#[derive(Debug)]
pub struct ProxyDigest {
/// The response header the proxy returns
pub response: Box<ResponseHeader>,
}
impl ProxyDigest {
pub fn new(response: Box<ResponseHeader>) -> Self {
ProxyDigest { response }
}
}
/// The error returned when the CONNECT proxy fails to establish.
#[derive(Debug)]
pub struct ConnectProxyError {
/// The response header the proxy returns
pub response: Box<ResponseHeader>,
}
impl ConnectProxyError {
pub fn boxed_new(response: Box<ResponseHeader>) -> Box<Self> {
Box::new(ConnectProxyError { response })
}
}
impl std::fmt::Display for ConnectProxyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
const PROXY_STATUS: &str = "proxy-status";
let reason = self
.response
.headers
.get(PROXY_STATUS)
.and_then(|s| s.to_str().ok())
.unwrap_or("missing proxy-status header value");
write!(
f,
"Failed CONNECT Response: status {}, proxy-status {reason}",
&self.response.status
)
}
}
impl std::error::Error for ConnectProxyError {}
#[inline]
fn http_req_header_to_wire_auth_form(req: &ReqHeader) -> BytesMut {
let mut buf = BytesMut::with_capacity(512);
// Request-Line
let method = req.method.as_str().as_bytes();
buf.put_slice(method);
buf.put_u8(b' ');
// NOTE: CONNECT doesn't need URI path so we just skip that
if let Some(path) = req.uri.authority() {
buf.put_slice(path.as_str().as_bytes());
}
buf.put_u8(b' ');
let version = match req.version {
Version::HTTP_09 => "HTTP/0.9",
Version::HTTP_10 => "HTTP/1.0",
Version::HTTP_11 => "HTTP/1.1",
_ => "HTTP/0.9",
};
buf.put_slice(version.as_bytes());
buf.put_slice(CRLF);
// headers
let headers = &req.headers;
for (key, value) in headers.iter() {
buf.put_slice(key.as_ref());
buf.put_slice(HEADER_KV_DELIMITER);
buf.put_slice(value.as_ref());
buf.put_slice(CRLF);
}
buf.put_slice(CRLF);
buf
}
#[inline]
fn validate_connect_response(resp: Box<ResponseHeader>) -> Result<ProxyDigest> {
if !resp.status.is_success() {
return Error::e_because(
ConnectProxyFailure,
            "Non-2xx status code",
ConnectProxyError::boxed_new(resp),
);
}
    // Strictly speaking, checking Content-Length and Transfer-Encoding is optional because we
    // ignore the response body anyway. We still reject Transfer-Encoding to be strict for
    // internal use of CONNECT, but tolerate Content-Length because our internal CONNECT server
    // is coded to send it.
if resp.headers.get(http::header::TRANSFER_ENCODING).is_some() {
return Error::e_because(
ConnectProxyFailure,
            "Invalid Transfer-Encoding header present",
ConnectProxyError::boxed_new(resp),
);
}
Ok(ProxyDigest::new(resp))
}
#[cfg(test)]
mod test_sync {
use super::*;
use std::collections::BTreeMap;
use tokio_test::io::Builder;
#[test]
fn test_generate_connect_header() {
let mut headers = BTreeMap::new();
headers.insert(String::from("foo"), b"bar".to_vec());
let req = generate_connect_header("pingora.org", 123, headers.iter()).unwrap();
assert_eq!(req.method, http::method::Method::CONNECT);
assert_eq!(req.uri.authority().unwrap(), "pingora.org:123");
assert_eq!(req.headers.get("Host").unwrap(), "pingora.org:123");
assert_eq!(req.headers.get("foo").unwrap(), "bar");
}
#[test]
fn test_request_to_wire_auth_form() {
let new_request = http::Request::builder()
.method("CONNECT")
.uri("https://pingora.org:123/")
.header("Foo", "Bar")
.body(())
.unwrap();
let (new_request, _) = new_request.into_parts();
let wire = http_req_header_to_wire_auth_form(&new_request);
assert_eq!(
&b"CONNECT pingora.org:123 HTTP/1.1\r\nfoo: Bar\r\n\r\n"[..],
&wire
);
}
#[test]
fn test_validate_connect_response() {
let resp = ResponseHeader::build(200, None).unwrap();
validate_connect_response(Box::new(resp)).unwrap();
let resp = ResponseHeader::build(404, None).unwrap();
assert!(validate_connect_response(Box::new(resp)).is_err());
let mut resp = ResponseHeader::build(200, None).unwrap();
resp.append_header("content-length", 0).unwrap();
assert!(validate_connect_response(Box::new(resp)).is_ok());
let mut resp = ResponseHeader::build(200, None).unwrap();
resp.append_header("transfer-encoding", 0).unwrap();
assert!(validate_connect_response(Box::new(resp)).is_err());
}
#[tokio::test]
async fn test_connect_write_request() {
let wire = b"CONNECT pingora.org:123 HTTP/1.1\r\nhost: pingora.org:123\r\n\r\n";
let mock_io = Box::new(Builder::new().write(wire).build());
let headers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
let req = generate_connect_header("pingora.org", 123, headers.iter()).unwrap();
// ConnectionClosed
assert!(connect(mock_io, &req).await.is_err());
let to_wire = b"CONNECT pingora.org:123 HTTP/1.1\r\nhost: pingora.org:123\r\n\r\n";
let from_wire = b"HTTP/1.1 200 OK\r\n\r\n";
let mock_io = Box::new(Builder::new().write(to_wire).read(from_wire).build());
let req = generate_connect_header("pingora.org", 123, headers.iter()).unwrap();
let result = connect(mock_io, &req).await;
assert!(result.is_ok());
}
}
