Introduce set_async_get_session_callback
This commit is contained in:
parent
8a26577b5d
commit
1ca7f76607
|
|
@ -1,6 +1,7 @@
|
|||
use boring::ex_data::Index;
|
||||
use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::convert::identity;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{ready, Context, Poll, Waker};
|
||||
|
|
@ -19,6 +20,12 @@ pub type BoxPrivateKeyMethodFuture =
|
|||
pub type BoxPrivateKeyMethodFinish =
|
||||
Box<dyn FnOnce(&mut ssl::SslRef, &mut [u8]) -> Result<usize, AsyncPrivateKeyMethodError>>;
|
||||
|
||||
/// The type of futures to pass to [`SslContextBuilderExt::set_async_get_session_callback`].
|
||||
pub type BoxGetSessionFuture = ExDataFuture<Option<BoxGetSessionFinish>>;
|
||||
|
||||
/// The type of callbacks returned by [`BoxSelectCertFuture`] methods.
|
||||
pub type BoxGetSessionFinish = Box<dyn FnOnce(&mut ssl::SslRef, &[u8]) -> Option<ssl::SslSession>>;
|
||||
|
||||
/// Convenience alias for futures stored in [`Ssl`] ex data by [`SslContextBuilderExt`] methods.
|
||||
///
|
||||
/// Public for documentation purposes.
|
||||
|
|
@ -31,6 +38,8 @@ pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy<Index<Ssl, Option<BoxSelectCert
|
|||
pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy<
|
||||
Index<Ssl, Option<BoxPrivateKeyMethodFuture>>,
|
||||
> = Lazy::new(|| Ssl::new_ex_index().unwrap());
|
||||
pub(crate) static SELECT_GET_SESSION_FUTURE_INDEX: Lazy<Index<Ssl, Option<BoxGetSessionFuture>>> =
|
||||
Lazy::new(|| Ssl::new_ex_index().unwrap());
|
||||
|
||||
/// Extensions to [`SslContextBuilder`].
|
||||
///
|
||||
|
|
@ -57,6 +66,23 @@ pub trait SslContextBuilderExt: private::Sealed {
|
|||
///
|
||||
/// See [`AsyncPrivateKeyMethod`] for more details.
|
||||
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod);
|
||||
|
||||
/// Sets a callback that is called when a client proposed to resume a session
|
||||
/// but it was not found in the internal cache.
|
||||
///
|
||||
/// The callback is passed a reference to the session ID provided by the client.
|
||||
/// It should return the session corresponding to that ID if available. This is
|
||||
/// only used for servers, not clients.
|
||||
///
|
||||
/// See [`SslContextBuilder::set_get_session_callback`] for the sync setter
|
||||
/// of this callback.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The returned [`SslSession`] must not be associated with a different [`SslContext`].
|
||||
unsafe fn set_async_get_session_callback<F>(&mut self, callback: F)
|
||||
where
|
||||
F: Fn(&mut ssl::SslRef, &[u8]) -> Option<BoxGetSessionFuture> + Send + Sync + 'static;
|
||||
}
|
||||
|
||||
impl SslContextBuilderExt for SslContextBuilder {
|
||||
|
|
@ -73,6 +99,7 @@ impl SslContextBuilderExt for SslContextBuilder {
|
|||
*SELECT_CERT_FUTURE_INDEX,
|
||||
ClientHello::ssl_mut,
|
||||
&callback,
|
||||
identity,
|
||||
);
|
||||
|
||||
let fut_result = match fut_poll_result {
|
||||
|
|
@ -89,6 +116,29 @@ impl SslContextBuilderExt for SslContextBuilder {
|
|||
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) {
|
||||
self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method)));
|
||||
}
|
||||
|
||||
unsafe fn set_async_get_session_callback<F>(&mut self, callback: F)
|
||||
where
|
||||
F: Fn(&mut ssl::SslRef, &[u8]) -> Option<BoxGetSessionFuture> + Send + Sync + 'static,
|
||||
{
|
||||
let async_callback = move |ssl: &mut ssl::SslRef, id: &[u8]| {
|
||||
let fut_poll_result = with_ex_data_future(
|
||||
&mut *ssl,
|
||||
*SELECT_GET_SESSION_FUTURE_INDEX,
|
||||
|ssl| ssl,
|
||||
|ssl| callback(ssl, id).ok_or(()),
|
||||
|option| option.ok_or(()),
|
||||
);
|
||||
|
||||
match fut_poll_result {
|
||||
Poll::Ready(Err(())) => Ok(None),
|
||||
Poll::Ready(Ok(finish)) => Ok(finish(ssl, id)),
|
||||
Poll::Pending => Err(ssl::GetSessionPendingError),
|
||||
}
|
||||
};
|
||||
|
||||
self.set_get_session_callback(async_callback)
|
||||
}
|
||||
}
|
||||
|
||||
/// A fatal error to be returned from async select certificate callbacks.
|
||||
|
|
@ -201,6 +251,7 @@ fn with_private_key_method(
|
|||
*SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX,
|
||||
|ssl| ssl,
|
||||
|ssl| create_fut(ssl, output),
|
||||
identity,
|
||||
);
|
||||
|
||||
let fut_result = match fut_poll_result {
|
||||
|
|
@ -217,11 +268,12 @@ fn with_private_key_method(
|
|||
///
|
||||
/// This function won't even bother storing the future in `index` if the future
|
||||
/// created by `create_fut` returns `Poll::Ready(_)` on the first poll call.
|
||||
fn with_ex_data_future<H, T, E>(
|
||||
fn with_ex_data_future<H, R, T, E>(
|
||||
ssl_handle: &mut H,
|
||||
index: Index<ssl::Ssl, Option<ExDataFuture<Result<T, E>>>>,
|
||||
index: Index<ssl::Ssl, Option<ExDataFuture<R>>>,
|
||||
get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef,
|
||||
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<Result<T, E>>, E>,
|
||||
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<R>, E>,
|
||||
into_result: impl Fn(R) -> Result<T, E>,
|
||||
) -> Poll<Result<T, E>> {
|
||||
let ssl = get_ssl_mut(ssl_handle);
|
||||
let waker = ssl
|
||||
|
|
@ -233,7 +285,7 @@ fn with_ex_data_future<H, T, E>(
|
|||
let mut ctx = Context::from_waker(&waker);
|
||||
|
||||
if let Some(data @ Some(_)) = ssl.ex_data_mut(index) {
|
||||
let fut_result = ready!(data.as_mut().unwrap().as_mut().poll(&mut ctx));
|
||||
let fut_result = into_result(ready!(data.as_mut().unwrap().as_mut().poll(&mut ctx)));
|
||||
|
||||
*data = None;
|
||||
|
||||
|
|
@ -242,7 +294,7 @@ fn with_ex_data_future<H, T, E>(
|
|||
let mut fut = create_fut(ssl_handle)?;
|
||||
|
||||
match fut.as_mut().poll(&mut ctx) {
|
||||
Poll::Ready(fut_result) => Poll::Ready(fut_result),
|
||||
Poll::Ready(fut_result) => Poll::Ready(into_result(fut_result)),
|
||||
Poll::Pending => {
|
||||
get_ssl_mut(ssl_handle).set_ex_data(index, Some(fut));
|
||||
|
||||
|
|
|
|||
|
|
@ -32,9 +32,9 @@ mod bridge;
|
|||
|
||||
use self::async_callbacks::TASK_WAKER_INDEX;
|
||||
pub use self::async_callbacks::{
|
||||
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError,
|
||||
BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish, BoxSelectCertFuture,
|
||||
ExDataFuture, SslContextBuilderExt,
|
||||
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
|
||||
BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
|
||||
BoxSelectCertFuture, ExDataFuture, SslContextBuilderExt,
|
||||
};
|
||||
use self::bridge::AsyncStreamBridge;
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,104 @@
|
|||
use boring::ssl::{SslOptions, SslRef, SslSession, SslSessionCacheMode, SslVersion};
|
||||
use futures::future;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::OnceLock;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::task::yield_now;
|
||||
use tokio_boring::{BoxGetSessionFinish, SslContextBuilderExt};
|
||||
|
||||
mod common;
|
||||
|
||||
use self::common::{create_acceptor, create_connector, create_listener};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test() {
|
||||
static FOUND_SESSION: AtomicBool = AtomicBool::new(false);
|
||||
static SERVER_SESSION_DER: OnceLock<Vec<u8>> = OnceLock::new();
|
||||
static CLIENT_SESSION_DER: OnceLock<Vec<u8>> = OnceLock::new();
|
||||
|
||||
let (listener, addr) = create_listener();
|
||||
|
||||
let acceptor = create_acceptor(move |builder| {
|
||||
builder
|
||||
.set_max_proto_version(Some(SslVersion::TLS1_2))
|
||||
.unwrap();
|
||||
builder.set_options(SslOptions::NO_TICKET);
|
||||
builder
|
||||
.set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL);
|
||||
builder.set_new_session_callback(|_, session| {
|
||||
SERVER_SESSION_DER.set(session.to_der().unwrap()).unwrap()
|
||||
});
|
||||
|
||||
unsafe {
|
||||
builder.set_async_get_session_callback(|_, _| {
|
||||
let Some(der) = SERVER_SESSION_DER.get() else {
|
||||
return None;
|
||||
};
|
||||
|
||||
Some(Box::pin(async move {
|
||||
yield_now().await;
|
||||
|
||||
FOUND_SESSION.store(true, Ordering::SeqCst);
|
||||
|
||||
Some(Box::new(|_: &mut SslRef, _: &[u8]| {
|
||||
Some(SslSession::from_der(der).unwrap())
|
||||
}) as BoxGetSessionFinish)
|
||||
}))
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let connector = create_connector(|builder| {
|
||||
builder.set_session_cache_mode(SslSessionCacheMode::CLIENT);
|
||||
builder.set_new_session_callback(|_, session| {
|
||||
CLIENT_SESSION_DER.set(session.to_der().unwrap()).unwrap()
|
||||
});
|
||||
|
||||
builder.set_ca_file("tests/cert.pem")
|
||||
});
|
||||
|
||||
let server = async move {
|
||||
tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(SERVER_SESSION_DER.get().is_some());
|
||||
assert!(!FOUND_SESSION.load(Ordering::SeqCst));
|
||||
|
||||
tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(FOUND_SESSION.load(Ordering::SeqCst));
|
||||
};
|
||||
|
||||
let client = async move {
|
||||
tokio_boring::connect(
|
||||
connector.configure().unwrap(),
|
||||
"localhost",
|
||||
TcpStream::connect(&addr).await.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let der = CLIENT_SESSION_DER.get().unwrap();
|
||||
|
||||
let mut config = connector.configure().unwrap();
|
||||
|
||||
unsafe {
|
||||
config
|
||||
.set_session(&SslSession::from_der(der).unwrap())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
tokio_boring::connect(
|
||||
config,
|
||||
"localhost",
|
||||
TcpStream::connect(&addr).await.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
};
|
||||
|
||||
future::join(server, client).await;
|
||||
}
|
||||
|
|
@ -17,27 +17,10 @@ pub(crate) fn create_server(
|
|||
impl Future<Output = Result<SslStream<TcpStream>, HandshakeError<TcpStream>>>,
|
||||
SocketAddr,
|
||||
) {
|
||||
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
|
||||
listener.set_nonblocking(true).unwrap();
|
||||
|
||||
let listener = TcpListener::from_std(listener).unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let (listener, addr) = create_listener();
|
||||
|
||||
let server = async move {
|
||||
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
|
||||
|
||||
acceptor
|
||||
.set_private_key_file("tests/key.pem", SslFiletype::PEM)
|
||||
.unwrap();
|
||||
|
||||
acceptor
|
||||
.set_certificate_chain_file("tests/cert.pem")
|
||||
.unwrap();
|
||||
|
||||
setup(&mut acceptor);
|
||||
|
||||
let acceptor = acceptor.build();
|
||||
let acceptor = create_acceptor(setup);
|
||||
|
||||
let stream = listener.accept().await.unwrap().0;
|
||||
|
||||
|
|
@ -47,21 +30,54 @@ pub(crate) fn create_server(
|
|||
(server, addr)
|
||||
}
|
||||
|
||||
pub(crate) fn create_listener() -> (TcpListener, SocketAddr) {
|
||||
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
|
||||
listener.set_nonblocking(true).unwrap();
|
||||
|
||||
let listener = TcpListener::from_std(listener).unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
(listener, addr)
|
||||
}
|
||||
|
||||
pub(crate) fn create_acceptor(setup: impl FnOnce(&mut SslAcceptorBuilder)) -> SslAcceptor {
|
||||
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
|
||||
|
||||
acceptor
|
||||
.set_private_key_file("tests/key.pem", SslFiletype::PEM)
|
||||
.unwrap();
|
||||
|
||||
acceptor
|
||||
.set_certificate_chain_file("tests/cert.pem")
|
||||
.unwrap();
|
||||
|
||||
setup(&mut acceptor);
|
||||
|
||||
acceptor.build()
|
||||
}
|
||||
|
||||
pub(crate) async fn connect(
|
||||
addr: SocketAddr,
|
||||
setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
|
||||
) -> Result<SslStream<TcpStream>, HandshakeError<TcpStream>> {
|
||||
let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
|
||||
|
||||
setup(&mut connector).unwrap();
|
||||
|
||||
let config = connector.build().configure().unwrap();
|
||||
let config = create_connector(setup).configure().unwrap();
|
||||
|
||||
let stream = TcpStream::connect(&addr).await.unwrap();
|
||||
|
||||
tokio_boring::connect(config, "localhost", stream).await
|
||||
}
|
||||
|
||||
pub(crate) fn create_connector(
|
||||
setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
|
||||
) -> SslConnector {
|
||||
let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
|
||||
|
||||
setup(&mut connector).unwrap();
|
||||
|
||||
connector.build()
|
||||
}
|
||||
|
||||
pub(crate) async fn with_trivial_client_server_exchange(
|
||||
server_setup: impl FnOnce(&mut SslAcceptorBuilder),
|
||||
) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue