Introduce HttpsLayer::set_ssl_callback

This lets us customize the Ssl of each connection,
like set_callback which lets us customize the ConnectConfiguration
a step earlier.
This commit is contained in:
Anthony Ramine 2023-12-20 19:56:43 +01:00 committed by Anthony Ramine
parent 9b0e422c8d
commit 3637bfed2f
2 changed files with 49 additions and 11 deletions

View File

@ -7,7 +7,8 @@ use antidote::Mutex;
use boring::error::ErrorStack; use boring::error::ErrorStack;
use boring::ex_data::Index; use boring::ex_data::Index;
use boring::ssl::{ use boring::ssl::{
ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, SslSessionCacheMode, ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, SslRef,
SslSessionCacheMode,
}; };
use http::uri::Scheme; use http::uri::Scheme;
use hyper::client::connect::{Connected, Connection}; use hyper::client::connect::{Connected, Connection};
@ -41,14 +42,16 @@ fn key_index() -> Result<Index<Ssl, SessionKey>, ErrorStack> {
struct Inner { struct Inner {
ssl: SslConnector, ssl: SslConnector,
cache: Arc<Mutex<SessionCache>>, cache: Arc<Mutex<SessionCache>>,
#[allow(clippy::type_complexity)] callback: Option<Callback>,
callback: Option< ssl_callback: Option<SslCallback>,
Arc<dyn Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + Sync + Send>,
>,
} }
type Callback =
Arc<dyn Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + Sync + Send>;
type SslCallback = Arc<dyn Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + Sync + Send>;
impl Inner { impl Inner {
fn setup_ssl(&self, uri: &Uri, host: &str) -> Result<ConnectConfiguration, ErrorStack> { fn setup_ssl(&self, uri: &Uri, host: &str) -> Result<Ssl, ErrorStack> {
let mut conf = self.ssl.configure()?; let mut conf = self.ssl.configure()?;
if let Some(ref callback) = self.callback { if let Some(ref callback) = self.callback {
@ -69,7 +72,13 @@ impl Inner {
let idx = key_index()?; let idx = key_index()?;
conf.set_ex_data(idx, key); conf.set_ex_data(idx, key);
Ok(conf) let mut ssl = conf.into_ssl(host)?;
if let Some(ref ssl_callback) = self.ssl_callback {
ssl_callback(&mut ssl, uri)?;
}
Ok(ssl)
} }
} }
@ -117,17 +126,30 @@ impl HttpsLayer {
ssl: ssl.build(), ssl: ssl.build(),
cache, cache,
callback: None, callback: None,
ssl_callback: None,
}, },
}) })
} }
/// Registers a callback which can customize the configuration of each connection. /// Registers a callback which can customize the configuration of each connection.
///
/// Unsuitable to change verify hostflags (with `config.param_mut().set_hostflags(…)`),
/// as they are reset after the callback is executed. Use [`Self::set_ssl_callback`]
/// instead.
pub fn set_callback<F>(&mut self, callback: F) pub fn set_callback<F>(&mut self, callback: F)
where where
F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send, F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
{ {
self.inner.callback = Some(Arc::new(callback)); self.inner.callback = Some(Arc::new(callback));
} }
/// Registers a callback which can customize the `Ssl` of each connection.
pub fn set_ssl_callback<F>(&mut self, callback: F)
where
F: Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
{
self.inner.ssl_callback = Some(Arc::new(callback));
}
} }
impl<S> Layer<S> for HttpsLayer { impl<S> Layer<S> for HttpsLayer {
@ -182,12 +204,24 @@ where
} }
/// Registers a callback which can customize the configuration of each connection. /// Registers a callback which can customize the configuration of each connection.
///
/// Unsuitable to change verify hostflags (with `config.param_mut().set_hostflags(…)`),
/// as they are reset after the callback is executed. Use [`Self::set_ssl_callback`]
/// instead.
pub fn set_callback<F>(&mut self, callback: F) pub fn set_callback<F>(&mut self, callback: F)
where where
F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send, F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
{ {
self.inner.callback = Some(Arc::new(callback)); self.inner.callback = Some(Arc::new(callback));
} }
/// Registers a callback which can customize the `Ssl` of each connection.
pub fn set_ssl_callback<F>(&mut self, callback: F)
where
F: Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
{
self.inner.ssl_callback = Some(Arc::new(callback));
}
} }
impl<S> Service<Uri> for HttpsConnector<S> impl<S> Service<Uri> for HttpsConnector<S>
@ -244,8 +278,10 @@ where
} }
} }
let config = inner.setup_ssl(&uri, host)?; let ssl = inner.setup_ssl(&uri, host)?;
let stream = tokio_boring::connect(config, host, conn).await?; let stream = tokio_boring::SslStreamBuilder::new(ssl, conn)
.connect()
.await?;
Ok(MaybeHttpsStream::Https(stream)) Ok(MaybeHttpsStream::Https(stream))
}; };

View File

@ -139,9 +139,11 @@ async fn alpn_h2() {
let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap(); let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap();
ssl.set_ca_file("test/root-ca.pem").unwrap(); ssl.set_ca_file("test/root-ca.pem").unwrap();
ssl.set_alpn_protos(b"\x02h2\x08http/1.1").unwrap();
let ssl = HttpsConnector::with_connector(connector, ssl).unwrap(); let mut ssl = HttpsConnector::with_connector(connector, ssl).unwrap();
ssl.set_ssl_callback(|ssl, _| ssl.set_alpn_protos(b"\x02h2\x08http/1.1"));
let client = Client::builder().build::<_, Body>(ssl); let client = Client::builder().build::<_, Body>(ssl);
let resp = client let resp = client