From b39a712076ef388a87918d75e324c33e4e150bc1 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 18 May 2019 10:27:40 -0700 Subject: [PATCH] Fix handling of session callbacks The session context is used for session callbacks rather than the normal context, which breaks state lookup when the context has been swapped out (e.g. for SNI). Since there isn't an accessor for the session context, we just store an extra reference in the SSL's ex data. Closes #1115 --- openssl/src/ssl/callbacks.rs | 8 +++++--- openssl/src/ssl/mod.rs | 10 ++++++++-- openssl/src/ssl/test/mod.rs | 30 +++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/openssl/src/ssl/callbacks.rs b/openssl/src/ssl/callbacks.rs index 75d10fb6..b0251834 100644 --- a/openssl/src/ssl/callbacks.rs +++ b/openssl/src/ssl/callbacks.rs @@ -24,7 +24,7 @@ use pkey::Params; use ssl::AlpnError; #[cfg(ossl111)] use ssl::{ClientHelloResponse, ExtensionContext}; -use ssl::{SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef}; +use ssl::{SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef, SESSION_CTX_INDEX}; #[cfg(ossl111)] use x509::X509Ref; use x509::{X509StoreContext, X509StoreContextRef}; @@ -353,7 +353,8 @@ where { let ssl = SslRef::from_ptr_mut(ssl); let callback = ssl - .ssl_context() + .ex_data(*SESSION_CTX_INDEX) + .expect("BUG: session context missing") .ex_data(SslContext::cached_ex_index::()) .expect("BUG: new session callback missing") as *const F; let session = SslSession::from_ptr(session); @@ -398,7 +399,8 @@ where { let ssl = SslRef::from_ptr_mut(ssl); let callback = ssl - .ssl_context() + .ex_data(*SESSION_CTX_INDEX) + .expect("BUG: session context missing") .ex_data(SslContext::cached_ex_index::()) .expect("BUG: get session callback missing") as *const F; let data = slice::from_raw_parts(data as *const u8, len as usize); diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 29b6e360..c1b86abf 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -486,6 +486,8 @@ impl NameType { lazy_static! { static ref INDEXES: Mutex> = Mutex::new(HashMap::new()); static ref SSL_INDEXES: Mutex> = Mutex::new(HashMap::new()); + + static ref SESSION_CTX_INDEX: Index = Ssl::new_ex_index().unwrap(); } unsafe extern "C" fn free_data_box( @@ -2278,10 +2280,14 @@ impl Ssl { /// This corresponds to [`SSL_new`]. /// /// [`SSL_new`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_new.html + // FIXME should take &SslContextRef pub fn new(ctx: &SslContext) -> Result { unsafe { - let ssl = cvt_p(ffi::SSL_new(ctx.as_ptr()))?; - Ok(Ssl::from_ptr(ssl)) + let ptr = cvt_p(ffi::SSL_new(ctx.as_ptr()))?; + let mut ssl = Ssl::from_ptr(ptr); + ssl.set_ex_data(*SESSION_CTX_INDEX, ctx.clone()); + + Ok(ssl) } } diff --git a/openssl/src/ssl/test/mod.rs b/openssl/src/ssl/test/mod.rs index e0f3c565..c1f730ec 100644 --- a/openssl/src/ssl/test/mod.rs +++ b/openssl/src/ssl/test/mod.rs @@ -30,7 +30,7 @@ use ssl::{ClientHelloResponse, ExtensionContext}; use ssl::{ Error, HandshakeError, MidHandshakeSslStream, ShutdownResult, ShutdownState, Ssl, SslAcceptor, SslConnector, SslContext, SslFiletype, SslMethod, SslOptions, SslSessionCacheMode, SslStream, - SslVerifyMode, StatusType, + SslVerifyMode, StatusType, SslContextBuilder }; #[cfg(ossl102)] use x509::store::X509StoreBuilder; @@ -1024,6 +1024,34 @@ fn new_session_callback() { assert!(CALLED_BACK.load(Ordering::SeqCst)); } +#[test] +fn new_session_callback_swapped_ctx() { + static CALLED_BACK: AtomicBool = AtomicBool::new(false); + + let mut server = Server::builder(); + server.ctx().set_session_id_context(b"foo").unwrap(); + + let server = server.build(); + + let mut client = server.client(); + + client + .ctx() + .set_session_cache_mode(SslSessionCacheMode::CLIENT | SslSessionCacheMode::NO_INTERNAL); + client + .ctx() + .set_new_session_callback(|_, _| CALLED_BACK.store(true, Ordering::SeqCst)); + + let mut client = client.build().builder(); + + let ctx = SslContextBuilder::new(SslMethod::tls()).unwrap().build(); + client.ssl().set_ssl_context(&ctx).unwrap(); + + client.connect(); + + assert!(CALLED_BACK.load(Ordering::SeqCst)); +} + #[test] fn keying_export() { let listener = TcpListener::bind("127.0.0.1:0").unwrap();