Allow returning GetSessionPendingError from get session callbacks

This commit is contained in:
Anthony Ramine 2023-10-12 15:31:35 +02:00 committed by Alessandro Ghedini
parent 1e2a4812d2
commit 8a26577b5d
3 changed files with 78 additions and 15 deletions

View File

@ -1,9 +1,9 @@
#![forbid(unsafe_op_in_unsafe_fn)]
use super::{
AlpnError, ClientHello, PrivateKeyMethod, PrivateKeyMethodError, SelectCertError, SniError,
Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef,
SslSignatureAlgorithm, SESSION_CTX_INDEX,
AlpnError, ClientHello, GetSessionPendingError, PrivateKeyMethod, PrivateKeyMethodError,
SelectCertError, SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession,
SslSessionRef, SslSignatureAlgorithm, SESSION_CTX_INDEX,
};
use crate::error::ErrorStack;
use crate::ffi;
@ -13,7 +13,6 @@ use foreign_types::ForeignTypeRef;
use libc::c_char;
use libc::{c_int, c_uchar, c_uint, c_void};
use std::ffi::CStr;
use std::mem;
use std::ptr;
use std::slice;
use std::str;
@ -323,7 +322,10 @@ pub(super) unsafe extern "C" fn raw_get_session<F>(
copy: *mut c_int,
) -> *mut ffi::SSL_SESSION
where
F: Fn(&mut SslRef, &[u8]) -> Option<SslSession> + 'static + Sync + Send,
F: Fn(&mut SslRef, &[u8]) -> Result<Option<SslSession>, GetSessionPendingError>
+ 'static
+ Sync
+ Send,
{
// SAFETY: boring provides valid inputs.
let ssl = unsafe { SslRef::from_ptr_mut(ssl) };
@ -342,13 +344,15 @@ where
let callback = unsafe { &*(callback as *const F) };
match callback(ssl, data) {
Some(session) => {
let p = session.as_ptr();
mem::forget(session);
Ok(Some(session)) => {
let p = session.into_ptr();
*copy = 0;
p
}
None => ptr::null_mut(),
Ok(None) => ptr::null_mut(),
Err(GetSessionPendingError) => unsafe { ffi::SSL_magic_pending_session_ptr() },
}
}

View File

@ -1599,12 +1599,15 @@ impl SslContextBuilder {
///
/// # Safety
///
/// The returned `SslSession` must not be associated with a different `SslContext`.
/// The returned [`SslSession`] must not be associated with a different [`SslContext`].
///
/// [`SSL_CTX_sess_set_get_cb`]: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_sess_set_new_cb.html
pub unsafe fn set_get_session_callback<F>(&mut self, callback: F)
where
F: Fn(&mut SslRef, &[u8]) -> Option<SslSession> + 'static + Sync + Send,
F: Fn(&mut SslRef, &[u8]) -> Result<Option<SslSession>, GetSessionPendingError>
+ 'static
+ Sync
+ Send,
{
self.set_ex_data(SslContext::cached_ex_index::<F>(), callback);
ffi::SSL_CTX_sess_set_get_cb(self.as_ptr(), Some(callbacks::raw_get_session::<F>));
@ -1978,6 +1981,13 @@ impl SslContextRef {
}
}
/// Error returned by the callback to get a session when operation
/// could not complete and should be retried later.
///
/// See [`SslContextBuilder::set_get_session_callback`].
#[derive(Debug)]
pub struct GetSessionPendingError;
#[cfg(not(any(feature = "fips", feature = "fips-link-precompiled")))]
type ProtosLen = usize;
#[cfg(any(feature = "fips", feature = "fips-link-precompiled"))]

View File

@ -1,10 +1,11 @@
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::OnceLock;
use crate::ssl::test::server::Server;
use crate::ssl::{
Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSession, SslSessionCacheMode,
SslVersion,
ErrorCode, GetSessionPendingError, HandshakeError, Ssl, SslContext, SslContextBuilder,
SslMethod, SslOptions, SslSession, SslSessionCacheMode, SslVersion,
};
#[test]
@ -53,7 +54,7 @@ fn new_get_session_callback() {
unsafe {
server.ctx().set_get_session_callback(|_, id| {
let Some(der) = SERVER_SESSION_DER.get() else {
return None;
return Ok(None);
};
let session = SslSession::from_der(der).unwrap();
@ -62,7 +63,7 @@ fn new_get_session_callback() {
assert_eq!(id, session.id());
Some(session)
Ok(Some(session))
});
}
server.ctx().set_session_id_context(b"foo").unwrap();
@ -100,6 +101,54 @@ fn new_get_session_callback() {
assert!(FOUND_SESSION.load(Ordering::SeqCst));
}
#[test]
fn new_get_session_callback_pending() {
static CALLED_SERVER_CALLBACK: AtomicBool = AtomicBool::new(false);
let mut server = Server::builder();
server
.ctx()
.set_max_proto_version(Some(SslVersion::TLS1_2))
.unwrap();
server.ctx().set_options(SslOptions::NO_TICKET);
server
.ctx()
.set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL);
unsafe {
server.ctx().set_get_session_callback(|_, _| {
if !CALLED_SERVER_CALLBACK.swap(true, Ordering::SeqCst) {
return Err(GetSessionPendingError);
}
Ok(None)
});
}
server.ctx().set_session_id_context(b"foo").unwrap();
server.err_cb(|error| {
let HandshakeError::WouldBlock(mid_handshake) = error else {
panic!("should be WouldBlock");
};
assert!(mid_handshake.error().would_block());
assert_eq!(mid_handshake.error().code(), ErrorCode::PENDING_SESSION);
let mut socket = mid_handshake.handshake().unwrap();
socket.write_all(&[0]).unwrap();
});
let server = server.build();
let mut client = server.client();
client
.ctx()
.set_session_cache_mode(SslSessionCacheMode::CLIENT);
client.connect();
}
#[test]
fn new_session_callback_swapped_ctx() {
static CALLED_BACK: AtomicBool = AtomicBool::new(false);