Allow returning GetSessionPendingError from get session callbacks

This commit is contained in:
Anthony Ramine 2023-10-12 15:31:35 +02:00 committed by Alessandro Ghedini
parent 1e2a4812d2
commit 8a26577b5d
3 changed files with 78 additions and 15 deletions

View File

@ -1,9 +1,9 @@
#![forbid(unsafe_op_in_unsafe_fn)] #![forbid(unsafe_op_in_unsafe_fn)]
use super::{ use super::{
AlpnError, ClientHello, PrivateKeyMethod, PrivateKeyMethodError, SelectCertError, SniError, AlpnError, ClientHello, GetSessionPendingError, PrivateKeyMethod, PrivateKeyMethodError,
Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef, SelectCertError, SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession,
SslSignatureAlgorithm, SESSION_CTX_INDEX, SslSessionRef, SslSignatureAlgorithm, SESSION_CTX_INDEX,
}; };
use crate::error::ErrorStack; use crate::error::ErrorStack;
use crate::ffi; use crate::ffi;
@ -13,7 +13,6 @@ use foreign_types::ForeignTypeRef;
use libc::c_char; use libc::c_char;
use libc::{c_int, c_uchar, c_uint, c_void}; use libc::{c_int, c_uchar, c_uint, c_void};
use std::ffi::CStr; use std::ffi::CStr;
use std::mem;
use std::ptr; use std::ptr;
use std::slice; use std::slice;
use std::str; use std::str;
@ -323,7 +322,10 @@ pub(super) unsafe extern "C" fn raw_get_session<F>(
copy: *mut c_int, copy: *mut c_int,
) -> *mut ffi::SSL_SESSION ) -> *mut ffi::SSL_SESSION
where where
F: Fn(&mut SslRef, &[u8]) -> Option<SslSession> + 'static + Sync + Send, F: Fn(&mut SslRef, &[u8]) -> Result<Option<SslSession>, GetSessionPendingError>
+ 'static
+ Sync
+ Send,
{ {
// SAFETY: boring provides valid inputs. // SAFETY: boring provides valid inputs.
let ssl = unsafe { SslRef::from_ptr_mut(ssl) }; let ssl = unsafe { SslRef::from_ptr_mut(ssl) };
@ -342,13 +344,15 @@ where
let callback = unsafe { &*(callback as *const F) }; let callback = unsafe { &*(callback as *const F) };
match callback(ssl, data) { match callback(ssl, data) {
Some(session) => { Ok(Some(session)) => {
let p = session.as_ptr(); let p = session.into_ptr();
mem::forget(session);
*copy = 0; *copy = 0;
p p
} }
None => ptr::null_mut(), Ok(None) => ptr::null_mut(),
Err(GetSessionPendingError) => unsafe { ffi::SSL_magic_pending_session_ptr() },
} }
} }

View File

@ -1599,12 +1599,15 @@ impl SslContextBuilder {
/// ///
/// # Safety /// # Safety
/// ///
/// The returned `SslSession` must not be associated with a different `SslContext`. /// The returned [`SslSession`] must not be associated with a different [`SslContext`].
/// ///
/// [`SSL_CTX_sess_set_get_cb`]: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_sess_set_new_cb.html /// [`SSL_CTX_sess_set_get_cb`]: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_sess_set_new_cb.html
pub unsafe fn set_get_session_callback<F>(&mut self, callback: F) pub unsafe fn set_get_session_callback<F>(&mut self, callback: F)
where where
F: Fn(&mut SslRef, &[u8]) -> Option<SslSession> + 'static + Sync + Send, F: Fn(&mut SslRef, &[u8]) -> Result<Option<SslSession>, GetSessionPendingError>
+ 'static
+ Sync
+ Send,
{ {
self.set_ex_data(SslContext::cached_ex_index::<F>(), callback); self.set_ex_data(SslContext::cached_ex_index::<F>(), callback);
ffi::SSL_CTX_sess_set_get_cb(self.as_ptr(), Some(callbacks::raw_get_session::<F>)); ffi::SSL_CTX_sess_set_get_cb(self.as_ptr(), Some(callbacks::raw_get_session::<F>));
@ -1978,6 +1981,13 @@ impl SslContextRef {
} }
} }
/// Error returned by the callback to get a session when operation
/// could not complete and should be retried later.
///
/// See [`SslContextBuilder::set_get_session_callback`].
#[derive(Debug)]
pub struct GetSessionPendingError;
#[cfg(not(any(feature = "fips", feature = "fips-link-precompiled")))] #[cfg(not(any(feature = "fips", feature = "fips-link-precompiled")))]
type ProtosLen = usize; type ProtosLen = usize;
#[cfg(any(feature = "fips", feature = "fips-link-precompiled"))] #[cfg(any(feature = "fips", feature = "fips-link-precompiled"))]

View File

@ -1,10 +1,11 @@
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::OnceLock; use std::sync::OnceLock;
use crate::ssl::test::server::Server; use crate::ssl::test::server::Server;
use crate::ssl::{ use crate::ssl::{
Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSession, SslSessionCacheMode, ErrorCode, GetSessionPendingError, HandshakeError, Ssl, SslContext, SslContextBuilder,
SslVersion, SslMethod, SslOptions, SslSession, SslSessionCacheMode, SslVersion,
}; };
#[test] #[test]
@ -53,7 +54,7 @@ fn new_get_session_callback() {
unsafe { unsafe {
server.ctx().set_get_session_callback(|_, id| { server.ctx().set_get_session_callback(|_, id| {
let Some(der) = SERVER_SESSION_DER.get() else { let Some(der) = SERVER_SESSION_DER.get() else {
return None; return Ok(None);
}; };
let session = SslSession::from_der(der).unwrap(); let session = SslSession::from_der(der).unwrap();
@ -62,7 +63,7 @@ fn new_get_session_callback() {
assert_eq!(id, session.id()); assert_eq!(id, session.id());
Some(session) Ok(Some(session))
}); });
} }
server.ctx().set_session_id_context(b"foo").unwrap(); server.ctx().set_session_id_context(b"foo").unwrap();
@ -100,6 +101,54 @@ fn new_get_session_callback() {
assert!(FOUND_SESSION.load(Ordering::SeqCst)); assert!(FOUND_SESSION.load(Ordering::SeqCst));
} }
#[test]
fn new_get_session_callback_pending() {
static CALLED_SERVER_CALLBACK: AtomicBool = AtomicBool::new(false);
let mut server = Server::builder();
server
.ctx()
.set_max_proto_version(Some(SslVersion::TLS1_2))
.unwrap();
server.ctx().set_options(SslOptions::NO_TICKET);
server
.ctx()
.set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL);
unsafe {
server.ctx().set_get_session_callback(|_, _| {
if !CALLED_SERVER_CALLBACK.swap(true, Ordering::SeqCst) {
return Err(GetSessionPendingError);
}
Ok(None)
});
}
server.ctx().set_session_id_context(b"foo").unwrap();
server.err_cb(|error| {
let HandshakeError::WouldBlock(mid_handshake) = error else {
panic!("should be WouldBlock");
};
assert!(mid_handshake.error().would_block());
assert_eq!(mid_handshake.error().code(), ErrorCode::PENDING_SESSION);
let mut socket = mid_handshake.handshake().unwrap();
socket.write_all(&[0]).unwrap();
});
let server = server.build();
let mut client = server.client();
client
.ctx()
.set_session_cache_mode(SslSessionCacheMode::CLIENT);
client.connect();
}
#[test] #[test]
fn new_session_callback_swapped_ctx() { fn new_session_callback_swapped_ctx() {
static CALLED_BACK: AtomicBool = AtomicBool::new(false); static CALLED_BACK: AtomicBool = AtomicBool::new(false);