diff --git a/openssl/src/ssl/callbacks.rs b/openssl/src/ssl/callbacks.rs index 75d10fb6..b0251834 100644 --- a/openssl/src/ssl/callbacks.rs +++ b/openssl/src/ssl/callbacks.rs @@ -24,7 +24,7 @@ use pkey::Params; use ssl::AlpnError; #[cfg(ossl111)] use ssl::{ClientHelloResponse, ExtensionContext}; -use ssl::{SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef}; +use ssl::{SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef, SESSION_CTX_INDEX}; #[cfg(ossl111)] use x509::X509Ref; use x509::{X509StoreContext, X509StoreContextRef}; @@ -353,7 +353,8 @@ where { let ssl = SslRef::from_ptr_mut(ssl); let callback = ssl - .ssl_context() + .ex_data(*SESSION_CTX_INDEX) + .expect("BUG: session context missing") .ex_data(SslContext::cached_ex_index::()) .expect("BUG: new session callback missing") as *const F; let session = SslSession::from_ptr(session); @@ -398,7 +399,8 @@ where { let ssl = SslRef::from_ptr_mut(ssl); let callback = ssl - .ssl_context() + .ex_data(*SESSION_CTX_INDEX) + .expect("BUG: session context missing") .ex_data(SslContext::cached_ex_index::()) .expect("BUG: get session callback missing") as *const F; let data = slice::from_raw_parts(data as *const u8, len as usize); diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 29b6e360..c1b86abf 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -486,6 +486,8 @@ impl NameType { lazy_static! { static ref INDEXES: Mutex> = Mutex::new(HashMap::new()); static ref SSL_INDEXES: Mutex> = Mutex::new(HashMap::new()); + + static ref SESSION_CTX_INDEX: Index = Ssl::new_ex_index().unwrap(); } unsafe extern "C" fn free_data_box( @@ -2278,10 +2280,14 @@ impl Ssl { /// This corresponds to [`SSL_new`]. /// /// [`SSL_new`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_new.html + // FIXME should take &SslContextRef pub fn new(ctx: &SslContext) -> Result { unsafe { - let ssl = cvt_p(ffi::SSL_new(ctx.as_ptr()))?; - Ok(Ssl::from_ptr(ssl)) + let ptr = cvt_p(ffi::SSL_new(ctx.as_ptr()))?; + let mut ssl = Ssl::from_ptr(ptr); + ssl.set_ex_data(*SESSION_CTX_INDEX, ctx.clone()); + + Ok(ssl) } } diff --git a/openssl/src/ssl/test/mod.rs b/openssl/src/ssl/test/mod.rs index e0f3c565..c1f730ec 100644 --- a/openssl/src/ssl/test/mod.rs +++ b/openssl/src/ssl/test/mod.rs @@ -30,7 +30,7 @@ use ssl::{ClientHelloResponse, ExtensionContext}; use ssl::{ Error, HandshakeError, MidHandshakeSslStream, ShutdownResult, ShutdownState, Ssl, SslAcceptor, SslConnector, SslContext, SslFiletype, SslMethod, SslOptions, SslSessionCacheMode, SslStream, - SslVerifyMode, StatusType, + SslVerifyMode, StatusType, SslContextBuilder }; #[cfg(ossl102)] use x509::store::X509StoreBuilder; @@ -1024,6 +1024,34 @@ fn new_session_callback() { assert!(CALLED_BACK.load(Ordering::SeqCst)); } +#[test] +fn new_session_callback_swapped_ctx() { + static CALLED_BACK: AtomicBool = AtomicBool::new(false); + + let mut server = Server::builder(); + server.ctx().set_session_id_context(b"foo").unwrap(); + + let server = server.build(); + + let mut client = server.client(); + + client + .ctx() + .set_session_cache_mode(SslSessionCacheMode::CLIENT | SslSessionCacheMode::NO_INTERNAL); + client + .ctx() + .set_new_session_callback(|_, _| CALLED_BACK.store(true, Ordering::SeqCst)); + + let mut client = client.build().builder(); + + let ctx = SslContextBuilder::new(SslMethod::tls()).unwrap().build(); + client.ssl().set_ssl_context(&ctx).unwrap(); + + client.connect(); + + assert!(CALLED_BACK.load(Ordering::SeqCst)); +} + #[test] fn keying_export() { let listener = TcpListener::bind("127.0.0.1:0").unwrap();