Allow returning GetSessionPendingError from get session callbacks
This commit is contained in:
parent
1e2a4812d2
commit
8a26577b5d
|
|
@ -1,9 +1,9 @@
|
|||
#![forbid(unsafe_op_in_unsafe_fn)]
|
||||
|
||||
use super::{
|
||||
AlpnError, ClientHello, PrivateKeyMethod, PrivateKeyMethodError, SelectCertError, SniError,
|
||||
Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef,
|
||||
SslSignatureAlgorithm, SESSION_CTX_INDEX,
|
||||
AlpnError, ClientHello, GetSessionPendingError, PrivateKeyMethod, PrivateKeyMethodError,
|
||||
SelectCertError, SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession,
|
||||
SslSessionRef, SslSignatureAlgorithm, SESSION_CTX_INDEX,
|
||||
};
|
||||
use crate::error::ErrorStack;
|
||||
use crate::ffi;
|
||||
|
|
@ -13,7 +13,6 @@ use foreign_types::ForeignTypeRef;
|
|||
use libc::c_char;
|
||||
use libc::{c_int, c_uchar, c_uint, c_void};
|
||||
use std::ffi::CStr;
|
||||
use std::mem;
|
||||
use std::ptr;
|
||||
use std::slice;
|
||||
use std::str;
|
||||
|
|
@ -323,7 +322,10 @@ pub(super) unsafe extern "C" fn raw_get_session<F>(
|
|||
copy: *mut c_int,
|
||||
) -> *mut ffi::SSL_SESSION
|
||||
where
|
||||
F: Fn(&mut SslRef, &[u8]) -> Option<SslSession> + 'static + Sync + Send,
|
||||
F: Fn(&mut SslRef, &[u8]) -> Result<Option<SslSession>, GetSessionPendingError>
|
||||
+ 'static
|
||||
+ Sync
|
||||
+ Send,
|
||||
{
|
||||
// SAFETY: boring provides valid inputs.
|
||||
let ssl = unsafe { SslRef::from_ptr_mut(ssl) };
|
||||
|
|
@ -342,13 +344,15 @@ where
|
|||
let callback = unsafe { &*(callback as *const F) };
|
||||
|
||||
match callback(ssl, data) {
|
||||
Some(session) => {
|
||||
let p = session.as_ptr();
|
||||
mem::forget(session);
|
||||
Ok(Some(session)) => {
|
||||
let p = session.into_ptr();
|
||||
|
||||
*copy = 0;
|
||||
|
||||
p
|
||||
}
|
||||
None => ptr::null_mut(),
|
||||
Ok(None) => ptr::null_mut(),
|
||||
Err(GetSessionPendingError) => unsafe { ffi::SSL_magic_pending_session_ptr() },
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1599,12 +1599,15 @@ impl SslContextBuilder {
|
|||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The returned `SslSession` must not be associated with a different `SslContext`.
|
||||
/// The returned [`SslSession`] must not be associated with a different [`SslContext`].
|
||||
///
|
||||
/// [`SSL_CTX_sess_set_get_cb`]: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_sess_set_new_cb.html
|
||||
pub unsafe fn set_get_session_callback<F>(&mut self, callback: F)
|
||||
where
|
||||
F: Fn(&mut SslRef, &[u8]) -> Option<SslSession> + 'static + Sync + Send,
|
||||
F: Fn(&mut SslRef, &[u8]) -> Result<Option<SslSession>, GetSessionPendingError>
|
||||
+ 'static
|
||||
+ Sync
|
||||
+ Send,
|
||||
{
|
||||
self.set_ex_data(SslContext::cached_ex_index::<F>(), callback);
|
||||
ffi::SSL_CTX_sess_set_get_cb(self.as_ptr(), Some(callbacks::raw_get_session::<F>));
|
||||
|
|
@ -1978,6 +1981,13 @@ impl SslContextRef {
|
|||
}
|
||||
}
|
||||
|
||||
/// Error returned by the callback to get a session when operation
|
||||
/// could not complete and should be retried later.
|
||||
///
|
||||
/// See [`SslContextBuilder::set_get_session_callback`].
|
||||
#[derive(Debug)]
|
||||
pub struct GetSessionPendingError;
|
||||
|
||||
#[cfg(not(any(feature = "fips", feature = "fips-link-precompiled")))]
|
||||
type ProtosLen = usize;
|
||||
#[cfg(any(feature = "fips", feature = "fips-link-precompiled"))]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
use std::io::Write;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use crate::ssl::test::server::Server;
|
||||
use crate::ssl::{
|
||||
Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSession, SslSessionCacheMode,
|
||||
SslVersion,
|
||||
ErrorCode, GetSessionPendingError, HandshakeError, Ssl, SslContext, SslContextBuilder,
|
||||
SslMethod, SslOptions, SslSession, SslSessionCacheMode, SslVersion,
|
||||
};
|
||||
|
||||
#[test]
|
||||
|
|
@ -53,7 +54,7 @@ fn new_get_session_callback() {
|
|||
unsafe {
|
||||
server.ctx().set_get_session_callback(|_, id| {
|
||||
let Some(der) = SERVER_SESSION_DER.get() else {
|
||||
return None;
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let session = SslSession::from_der(der).unwrap();
|
||||
|
|
@ -62,7 +63,7 @@ fn new_get_session_callback() {
|
|||
|
||||
assert_eq!(id, session.id());
|
||||
|
||||
Some(session)
|
||||
Ok(Some(session))
|
||||
});
|
||||
}
|
||||
server.ctx().set_session_id_context(b"foo").unwrap();
|
||||
|
|
@ -100,6 +101,54 @@ fn new_get_session_callback() {
|
|||
assert!(FOUND_SESSION.load(Ordering::SeqCst));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_get_session_callback_pending() {
|
||||
static CALLED_SERVER_CALLBACK: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
let mut server = Server::builder();
|
||||
|
||||
server
|
||||
.ctx()
|
||||
.set_max_proto_version(Some(SslVersion::TLS1_2))
|
||||
.unwrap();
|
||||
server.ctx().set_options(SslOptions::NO_TICKET);
|
||||
server
|
||||
.ctx()
|
||||
.set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL);
|
||||
unsafe {
|
||||
server.ctx().set_get_session_callback(|_, _| {
|
||||
if !CALLED_SERVER_CALLBACK.swap(true, Ordering::SeqCst) {
|
||||
return Err(GetSessionPendingError);
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
});
|
||||
}
|
||||
server.ctx().set_session_id_context(b"foo").unwrap();
|
||||
server.err_cb(|error| {
|
||||
let HandshakeError::WouldBlock(mid_handshake) = error else {
|
||||
panic!("should be WouldBlock");
|
||||
};
|
||||
|
||||
assert!(mid_handshake.error().would_block());
|
||||
assert_eq!(mid_handshake.error().code(), ErrorCode::PENDING_SESSION);
|
||||
|
||||
let mut socket = mid_handshake.handshake().unwrap();
|
||||
|
||||
socket.write_all(&[0]).unwrap();
|
||||
});
|
||||
|
||||
let server = server.build();
|
||||
|
||||
let mut client = server.client();
|
||||
|
||||
client
|
||||
.ctx()
|
||||
.set_session_cache_mode(SslSessionCacheMode::CLIENT);
|
||||
|
||||
client.connect();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_session_callback_swapped_ctx() {
|
||||
static CALLED_BACK: AtomicBool = AtomicBool::new(false);
|
||||
|
|
|
|||
Loading…
Reference in New Issue