Introduce HttpsLayer::set_ssl_callback

This lets us customize the Ssl of each connection,
like set_callback which lets us customize the ConnectConfiguration
a step earlier.
This commit is contained in:
Anthony Ramine 2023-12-20 19:56:43 +01:00 committed by Anthony Ramine
parent 9b0e422c8d
commit 3637bfed2f
2 changed files with 49 additions and 11 deletions

View File

@ -7,7 +7,8 @@ use antidote::Mutex;
use boring::error::ErrorStack;
use boring::ex_data::Index;
use boring::ssl::{
ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, SslSessionCacheMode,
ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, SslRef,
SslSessionCacheMode,
};
use http::uri::Scheme;
use hyper::client::connect::{Connected, Connection};
@ -41,14 +42,16 @@ fn key_index() -> Result<Index<Ssl, SessionKey>, ErrorStack> {
struct Inner {
ssl: SslConnector,
cache: Arc<Mutex<SessionCache>>,
#[allow(clippy::type_complexity)]
callback: Option<
Arc<dyn Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + Sync + Send>,
>,
callback: Option<Callback>,
ssl_callback: Option<SslCallback>,
}
type Callback =
Arc<dyn Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + Sync + Send>;
type SslCallback = Arc<dyn Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + Sync + Send>;
impl Inner {
fn setup_ssl(&self, uri: &Uri, host: &str) -> Result<ConnectConfiguration, ErrorStack> {
fn setup_ssl(&self, uri: &Uri, host: &str) -> Result<Ssl, ErrorStack> {
let mut conf = self.ssl.configure()?;
if let Some(ref callback) = self.callback {
@ -69,7 +72,13 @@ impl Inner {
let idx = key_index()?;
conf.set_ex_data(idx, key);
Ok(conf)
let mut ssl = conf.into_ssl(host)?;
if let Some(ref ssl_callback) = self.ssl_callback {
ssl_callback(&mut ssl, uri)?;
}
Ok(ssl)
}
}
@ -117,17 +126,30 @@ impl HttpsLayer {
ssl: ssl.build(),
cache,
callback: None,
ssl_callback: None,
},
})
}
/// Registers a callback which can customize the configuration of each connection.
///
/// Unsuitable to change verify hostflags (with `config.param_mut().set_hostflags(…)`),
/// as they are reset after the callback is executed. Use [`Self::set_ssl_callback`]
/// instead.
pub fn set_callback<F>(&mut self, callback: F)
where
F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
{
self.inner.callback = Some(Arc::new(callback));
}
/// Registers a callback which can customize the `Ssl` of each connection.
pub fn set_ssl_callback<F>(&mut self, callback: F)
where
F: Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
{
self.inner.ssl_callback = Some(Arc::new(callback));
}
}
impl<S> Layer<S> for HttpsLayer {
@ -182,12 +204,24 @@ where
}
/// Registers a callback which can customize the configuration of each connection.
///
/// Unsuitable to change verify hostflags (with `config.param_mut().set_hostflags(…)`),
/// as they are reset after the callback is executed. Use [`Self::set_ssl_callback`]
/// instead.
pub fn set_callback<F>(&mut self, callback: F)
where
F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
{
self.inner.callback = Some(Arc::new(callback));
}
/// Registers a callback which can customize the `Ssl` of each connection.
pub fn set_ssl_callback<F>(&mut self, callback: F)
where
F: Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
{
self.inner.ssl_callback = Some(Arc::new(callback));
}
}
impl<S> Service<Uri> for HttpsConnector<S>
@ -244,8 +278,10 @@ where
}
}
let config = inner.setup_ssl(&uri, host)?;
let stream = tokio_boring::connect(config, host, conn).await?;
let ssl = inner.setup_ssl(&uri, host)?;
let stream = tokio_boring::SslStreamBuilder::new(ssl, conn)
.connect()
.await?;
Ok(MaybeHttpsStream::Https(stream))
};

View File

@ -139,9 +139,11 @@ async fn alpn_h2() {
let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap();
ssl.set_ca_file("test/root-ca.pem").unwrap();
ssl.set_alpn_protos(b"\x02h2\x08http/1.1").unwrap();
let ssl = HttpsConnector::with_connector(connector, ssl).unwrap();
let mut ssl = HttpsConnector::with_connector(connector, ssl).unwrap();
ssl.set_ssl_callback(|ssl, _| ssl.set_alpn_protos(b"\x02h2\x08http/1.1"));
let client = Client::builder().build::<_, Body>(ssl);
let resp = client