Introduce set_async_get_session_callback

This commit is contained in:
Anthony Ramine 2023-10-12 15:58:57 +02:00 committed by Alessandro Ghedini
parent 8a26577b5d
commit 1ca7f76607
4 changed files with 204 additions and 32 deletions

View File

@ -1,6 +1,7 @@
use boring::ex_data::Index;
use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder};
use once_cell::sync::Lazy;
use std::convert::identity;
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll, Waker};
@ -19,6 +20,12 @@ pub type BoxPrivateKeyMethodFuture =
pub type BoxPrivateKeyMethodFinish =
Box<dyn FnOnce(&mut ssl::SslRef, &mut [u8]) -> Result<usize, AsyncPrivateKeyMethodError>>;
/// The type of futures to pass to [`SslContextBuilderExt::set_async_get_session_callback`].
pub type BoxGetSessionFuture = ExDataFuture<Option<BoxGetSessionFinish>>;
/// The type of callbacks returned by [`BoxSelectCertFuture`] methods.
pub type BoxGetSessionFinish = Box<dyn FnOnce(&mut ssl::SslRef, &[u8]) -> Option<ssl::SslSession>>;
/// Convenience alias for futures stored in [`Ssl`] ex data by [`SslContextBuilderExt`] methods.
///
/// Public for documentation purposes.
@ -31,6 +38,8 @@ pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy<Index<Ssl, Option<BoxSelectCert
pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy<
Index<Ssl, Option<BoxPrivateKeyMethodFuture>>,
> = Lazy::new(|| Ssl::new_ex_index().unwrap());
pub(crate) static SELECT_GET_SESSION_FUTURE_INDEX: Lazy<Index<Ssl, Option<BoxGetSessionFuture>>> =
Lazy::new(|| Ssl::new_ex_index().unwrap());
/// Extensions to [`SslContextBuilder`].
///
@ -57,6 +66,23 @@ pub trait SslContextBuilderExt: private::Sealed {
///
/// See [`AsyncPrivateKeyMethod`] for more details.
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod);
/// Sets a callback that is called when a client proposed to resume a session
/// but it was not found in the internal cache.
///
/// The callback is passed a reference to the session ID provided by the client.
/// It should return the session corresponding to that ID if available. This is
/// only used for servers, not clients.
///
/// See [`SslContextBuilder::set_get_session_callback`] for the sync setter
/// of this callback.
///
/// # Safety
///
/// The returned [`SslSession`] must not be associated with a different [`SslContext`].
unsafe fn set_async_get_session_callback<F>(&mut self, callback: F)
where
F: Fn(&mut ssl::SslRef, &[u8]) -> Option<BoxGetSessionFuture> + Send + Sync + 'static;
}
impl SslContextBuilderExt for SslContextBuilder {
@ -73,6 +99,7 @@ impl SslContextBuilderExt for SslContextBuilder {
*SELECT_CERT_FUTURE_INDEX,
ClientHello::ssl_mut,
&callback,
identity,
);
let fut_result = match fut_poll_result {
@ -89,6 +116,29 @@ impl SslContextBuilderExt for SslContextBuilder {
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) {
self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method)));
}
unsafe fn set_async_get_session_callback<F>(&mut self, callback: F)
where
F: Fn(&mut ssl::SslRef, &[u8]) -> Option<BoxGetSessionFuture> + Send + Sync + 'static,
{
let async_callback = move |ssl: &mut ssl::SslRef, id: &[u8]| {
let fut_poll_result = with_ex_data_future(
&mut *ssl,
*SELECT_GET_SESSION_FUTURE_INDEX,
|ssl| ssl,
|ssl| callback(ssl, id).ok_or(()),
|option| option.ok_or(()),
);
match fut_poll_result {
Poll::Ready(Err(())) => Ok(None),
Poll::Ready(Ok(finish)) => Ok(finish(ssl, id)),
Poll::Pending => Err(ssl::GetSessionPendingError),
}
};
self.set_get_session_callback(async_callback)
}
}
/// A fatal error to be returned from async select certificate callbacks.
@ -201,6 +251,7 @@ fn with_private_key_method(
*SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX,
|ssl| ssl,
|ssl| create_fut(ssl, output),
identity,
);
let fut_result = match fut_poll_result {
@ -217,11 +268,12 @@ fn with_private_key_method(
///
/// This function won't even bother storing the future in `index` if the future
/// created by `create_fut` returns `Poll::Ready(_)` on the first poll call.
fn with_ex_data_future<H, T, E>(
fn with_ex_data_future<H, R, T, E>(
ssl_handle: &mut H,
index: Index<ssl::Ssl, Option<ExDataFuture<Result<T, E>>>>,
index: Index<ssl::Ssl, Option<ExDataFuture<R>>>,
get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef,
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<Result<T, E>>, E>,
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<R>, E>,
into_result: impl Fn(R) -> Result<T, E>,
) -> Poll<Result<T, E>> {
let ssl = get_ssl_mut(ssl_handle);
let waker = ssl
@ -233,7 +285,7 @@ fn with_ex_data_future<H, T, E>(
let mut ctx = Context::from_waker(&waker);
if let Some(data @ Some(_)) = ssl.ex_data_mut(index) {
let fut_result = ready!(data.as_mut().unwrap().as_mut().poll(&mut ctx));
let fut_result = into_result(ready!(data.as_mut().unwrap().as_mut().poll(&mut ctx)));
*data = None;
@ -242,7 +294,7 @@ fn with_ex_data_future<H, T, E>(
let mut fut = create_fut(ssl_handle)?;
match fut.as_mut().poll(&mut ctx) {
Poll::Ready(fut_result) => Poll::Ready(fut_result),
Poll::Ready(fut_result) => Poll::Ready(into_result(fut_result)),
Poll::Pending => {
get_ssl_mut(ssl_handle).set_ex_data(index, Some(fut));

View File

@ -32,9 +32,9 @@ mod bridge;
use self::async_callbacks::TASK_WAKER_INDEX;
pub use self::async_callbacks::{
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError,
BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish, BoxSelectCertFuture,
ExDataFuture, SslContextBuilderExt,
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
BoxSelectCertFuture, ExDataFuture, SslContextBuilderExt,
};
use self::bridge::AsyncStreamBridge;

View File

@ -0,0 +1,104 @@
use boring::ssl::{SslOptions, SslRef, SslSession, SslSessionCacheMode, SslVersion};
use futures::future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::OnceLock;
use tokio::net::TcpStream;
use tokio::task::yield_now;
use tokio_boring::{BoxGetSessionFinish, SslContextBuilderExt};
mod common;
use self::common::{create_acceptor, create_connector, create_listener};
#[tokio::test]
async fn test() {
static FOUND_SESSION: AtomicBool = AtomicBool::new(false);
static SERVER_SESSION_DER: OnceLock<Vec<u8>> = OnceLock::new();
static CLIENT_SESSION_DER: OnceLock<Vec<u8>> = OnceLock::new();
let (listener, addr) = create_listener();
let acceptor = create_acceptor(move |builder| {
builder
.set_max_proto_version(Some(SslVersion::TLS1_2))
.unwrap();
builder.set_options(SslOptions::NO_TICKET);
builder
.set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL);
builder.set_new_session_callback(|_, session| {
SERVER_SESSION_DER.set(session.to_der().unwrap()).unwrap()
});
unsafe {
builder.set_async_get_session_callback(|_, _| {
let Some(der) = SERVER_SESSION_DER.get() else {
return None;
};
Some(Box::pin(async move {
yield_now().await;
FOUND_SESSION.store(true, Ordering::SeqCst);
Some(Box::new(|_: &mut SslRef, _: &[u8]| {
Some(SslSession::from_der(der).unwrap())
}) as BoxGetSessionFinish)
}))
});
}
});
let connector = create_connector(|builder| {
builder.set_session_cache_mode(SslSessionCacheMode::CLIENT);
builder.set_new_session_callback(|_, session| {
CLIENT_SESSION_DER.set(session.to_der().unwrap()).unwrap()
});
builder.set_ca_file("tests/cert.pem")
});
let server = async move {
tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0)
.await
.unwrap();
assert!(SERVER_SESSION_DER.get().is_some());
assert!(!FOUND_SESSION.load(Ordering::SeqCst));
tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0)
.await
.unwrap();
assert!(FOUND_SESSION.load(Ordering::SeqCst));
};
let client = async move {
tokio_boring::connect(
connector.configure().unwrap(),
"localhost",
TcpStream::connect(&addr).await.unwrap(),
)
.await
.unwrap();
let der = CLIENT_SESSION_DER.get().unwrap();
let mut config = connector.configure().unwrap();
unsafe {
config
.set_session(&SslSession::from_der(der).unwrap())
.unwrap();
}
tokio_boring::connect(
config,
"localhost",
TcpStream::connect(&addr).await.unwrap(),
)
.await
.unwrap();
};
future::join(server, client).await;
}

View File

@ -17,6 +17,20 @@ pub(crate) fn create_server(
impl Future<Output = Result<SslStream<TcpStream>, HandshakeError<TcpStream>>>,
SocketAddr,
) {
let (listener, addr) = create_listener();
let server = async move {
let acceptor = create_acceptor(setup);
let stream = listener.accept().await.unwrap().0;
tokio_boring::accept(&acceptor, stream).await
};
(server, addr)
}
pub(crate) fn create_listener() -> (TcpListener, SocketAddr) {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
listener.set_nonblocking(true).unwrap();
@ -24,7 +38,10 @@ pub(crate) fn create_server(
let listener = TcpListener::from_std(listener).unwrap();
let addr = listener.local_addr().unwrap();
let server = async move {
(listener, addr)
}
pub(crate) fn create_acceptor(setup: impl FnOnce(&mut SslAcceptorBuilder)) -> SslAcceptor {
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
acceptor
@ -37,31 +54,30 @@ pub(crate) fn create_server(
setup(&mut acceptor);
let acceptor = acceptor.build();
let stream = listener.accept().await.unwrap().0;
tokio_boring::accept(&acceptor, stream).await
};
(server, addr)
acceptor.build()
}
pub(crate) async fn connect(
addr: SocketAddr,
setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
) -> Result<SslStream<TcpStream>, HandshakeError<TcpStream>> {
let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
setup(&mut connector).unwrap();
let config = connector.build().configure().unwrap();
let config = create_connector(setup).configure().unwrap();
let stream = TcpStream::connect(&addr).await.unwrap();
tokio_boring::connect(config, "localhost", stream).await
}
pub(crate) fn create_connector(
setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
) -> SslConnector {
let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
setup(&mut connector).unwrap();
connector.build()
}
pub(crate) async fn with_trivial_client_server_exchange(
server_setup: impl FnOnce(&mut SslAcceptorBuilder),
) {