From 8a26577b5da70ceb8e4b4ace72e9f0c7279b1814 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Thu, 12 Oct 2023 15:31:35 +0200 Subject: [PATCH] Allow returning GetSessionPendingError from get session callbacks --- boring/src/ssl/callbacks.rs | 22 +++++++------ boring/src/ssl/mod.rs | 14 +++++++-- boring/src/ssl/test/session.rs | 57 +++++++++++++++++++++++++++++++--- 3 files changed, 78 insertions(+), 15 deletions(-) diff --git a/boring/src/ssl/callbacks.rs b/boring/src/ssl/callbacks.rs index dc9f2d53..c4bcb3f0 100644 --- a/boring/src/ssl/callbacks.rs +++ b/boring/src/ssl/callbacks.rs @@ -1,9 +1,9 @@ #![forbid(unsafe_op_in_unsafe_fn)] use super::{ - AlpnError, ClientHello, PrivateKeyMethod, PrivateKeyMethodError, SelectCertError, SniError, - Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef, - SslSignatureAlgorithm, SESSION_CTX_INDEX, + AlpnError, ClientHello, GetSessionPendingError, PrivateKeyMethod, PrivateKeyMethodError, + SelectCertError, SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, + SslSessionRef, SslSignatureAlgorithm, SESSION_CTX_INDEX, }; use crate::error::ErrorStack; use crate::ffi; @@ -13,7 +13,6 @@ use foreign_types::ForeignTypeRef; use libc::c_char; use libc::{c_int, c_uchar, c_uint, c_void}; use std::ffi::CStr; -use std::mem; use std::ptr; use std::slice; use std::str; @@ -323,7 +322,10 @@ pub(super) unsafe extern "C" fn raw_get_session( copy: *mut c_int, ) -> *mut ffi::SSL_SESSION where - F: Fn(&mut SslRef, &[u8]) -> Option + 'static + Sync + Send, + F: Fn(&mut SslRef, &[u8]) -> Result, GetSessionPendingError> + + 'static + + Sync + + Send, { // SAFETY: boring provides valid inputs. let ssl = unsafe { SslRef::from_ptr_mut(ssl) }; @@ -342,13 +344,15 @@ where let callback = unsafe { &*(callback as *const F) }; match callback(ssl, data) { - Some(session) => { - let p = session.as_ptr(); - mem::forget(session); + Ok(Some(session)) => { + let p = session.into_ptr(); + *copy = 0; + p } - None => ptr::null_mut(), + Ok(None) => ptr::null_mut(), + Err(GetSessionPendingError) => unsafe { ffi::SSL_magic_pending_session_ptr() }, } } diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 60a649e9..1420ee32 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -1599,12 +1599,15 @@ impl SslContextBuilder { /// /// # Safety /// - /// The returned `SslSession` must not be associated with a different `SslContext`. + /// The returned [`SslSession`] must not be associated with a different [`SslContext`]. /// /// [`SSL_CTX_sess_set_get_cb`]: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_sess_set_new_cb.html pub unsafe fn set_get_session_callback(&mut self, callback: F) where - F: Fn(&mut SslRef, &[u8]) -> Option + 'static + Sync + Send, + F: Fn(&mut SslRef, &[u8]) -> Result, GetSessionPendingError> + + 'static + + Sync + + Send, { self.set_ex_data(SslContext::cached_ex_index::(), callback); ffi::SSL_CTX_sess_set_get_cb(self.as_ptr(), Some(callbacks::raw_get_session::)); @@ -1978,6 +1981,13 @@ impl SslContextRef { } } +/// Error returned by the callback to get a session when operation +/// could not complete and should be retried later. +/// +/// See [`SslContextBuilder::set_get_session_callback`]. +#[derive(Debug)] +pub struct GetSessionPendingError; + #[cfg(not(any(feature = "fips", feature = "fips-link-precompiled")))] type ProtosLen = usize; #[cfg(any(feature = "fips", feature = "fips-link-precompiled"))] diff --git a/boring/src/ssl/test/session.rs b/boring/src/ssl/test/session.rs index 941f3545..23c0f4d5 100644 --- a/boring/src/ssl/test/session.rs +++ b/boring/src/ssl/test/session.rs @@ -1,10 +1,11 @@ +use std::io::Write; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::OnceLock; use crate::ssl::test::server::Server; use crate::ssl::{ - Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSession, SslSessionCacheMode, - SslVersion, + ErrorCode, GetSessionPendingError, HandshakeError, Ssl, SslContext, SslContextBuilder, + SslMethod, SslOptions, SslSession, SslSessionCacheMode, SslVersion, }; #[test] @@ -53,7 +54,7 @@ fn new_get_session_callback() { unsafe { server.ctx().set_get_session_callback(|_, id| { let Some(der) = SERVER_SESSION_DER.get() else { - return None; + return Ok(None); }; let session = SslSession::from_der(der).unwrap(); @@ -62,7 +63,7 @@ fn new_get_session_callback() { assert_eq!(id, session.id()); - Some(session) + Ok(Some(session)) }); } server.ctx().set_session_id_context(b"foo").unwrap(); @@ -100,6 +101,54 @@ fn new_get_session_callback() { assert!(FOUND_SESSION.load(Ordering::SeqCst)); } +#[test] +fn new_get_session_callback_pending() { + static CALLED_SERVER_CALLBACK: AtomicBool = AtomicBool::new(false); + + let mut server = Server::builder(); + + server + .ctx() + .set_max_proto_version(Some(SslVersion::TLS1_2)) + .unwrap(); + server.ctx().set_options(SslOptions::NO_TICKET); + server + .ctx() + .set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL); + unsafe { + server.ctx().set_get_session_callback(|_, _| { + if !CALLED_SERVER_CALLBACK.swap(true, Ordering::SeqCst) { + return Err(GetSessionPendingError); + } + + Ok(None) + }); + } + server.ctx().set_session_id_context(b"foo").unwrap(); + server.err_cb(|error| { + let HandshakeError::WouldBlock(mid_handshake) = error else { + panic!("should be WouldBlock"); + }; + + assert!(mid_handshake.error().would_block()); + assert_eq!(mid_handshake.error().code(), ErrorCode::PENDING_SESSION); + + let mut socket = mid_handshake.handshake().unwrap(); + + socket.write_all(&[0]).unwrap(); + }); + + let server = server.build(); + + let mut client = server.client(); + + client + .ctx() + .set_session_cache_mode(SslSessionCacheMode::CLIENT); + + client.connect(); +} + #[test] fn new_session_callback_swapped_ctx() { static CALLED_BACK: AtomicBool = AtomicBool::new(false);