From 1e2a4812d2793634cddd4ae8ec051cccbca3bbf5 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Tue, 24 Oct 2023 14:23:14 +0200 Subject: [PATCH] Test set_get_session_callback --- boring/src/ssl/test/server.rs | 35 ++++++++++++++------ boring/src/ssl/test/session.rs | 59 ++++++++++++++++++++++++++++------ 2 files changed, 74 insertions(+), 20 deletions(-) diff --git a/boring/src/ssl/test/server.rs b/boring/src/ssl/test/server.rs index 83562331..e5c0497c 100644 --- a/boring/src/ssl/test/server.rs +++ b/boring/src/ssl/test/server.rs @@ -32,6 +32,7 @@ impl Server { io_cb: Box::new(|_| {}), err_cb: Box::new(|_| {}), should_error: false, + expected_connections_count: 1, } } @@ -61,6 +62,7 @@ pub struct Builder { io_cb: Box) + Send>, err_cb: Box) + Send>, should_error: bool, + expected_connections_count: usize, } impl Builder { @@ -92,6 +94,10 @@ impl Builder { self.should_error = true; } + pub fn expected_connections_count(&mut self, count: usize) { + self.expected_connections_count = count; + } + pub fn build(self) -> Server { let ctx = self.ctx.build(); let socket = TcpListener::bind("127.0.0.1:0").unwrap(); @@ -100,18 +106,27 @@ impl Builder { let mut io_cb = self.io_cb; let mut err_cb = self.err_cb; let should_error = self.should_error; + let mut count = self.expected_connections_count; let handle = thread::spawn(move || { - let socket = socket.accept().unwrap().0; - let mut ssl = Ssl::new(&ctx).unwrap(); - ssl_cb(&mut ssl); - let r = ssl.accept(socket); - if should_error { - err_cb(r.unwrap_err()); - } else { - let mut socket = r.unwrap(); - socket.write_all(&[0]).unwrap(); - io_cb(socket); + while count > 0 { + let socket = socket.accept().unwrap().0; + let mut ssl = Ssl::new(&ctx).unwrap(); + + ssl_cb(&mut ssl); + + let r = ssl.accept(socket); + + if should_error { + err_cb(r.unwrap_err()); + } else { + let mut socket = r.unwrap(); + + socket.write_all(&[0]).unwrap(); + io_cb(socket); + } + + count -= 1; } }); diff --git a/boring/src/ssl/test/session.rs b/boring/src/ssl/test/session.rs index 2754ad3b..941f3545 100644 --- a/boring/src/ssl/test/session.rs +++ b/boring/src/ssl/test/session.rs @@ -3,7 +3,8 @@ use std::sync::OnceLock; use crate::ssl::test::server::Server; use crate::ssl::{ - Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSessionCacheMode, SslVersion, + Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSession, SslSessionCacheMode, + SslVersion, }; #[test] @@ -30,11 +31,14 @@ fn active_session() { } #[test] -fn new_session_callback() { - static SESSION_ID: OnceLock> = OnceLock::new(); +fn new_get_session_callback() { + static FOUND_SESSION: AtomicBool = AtomicBool::new(false); + static SERVER_SESSION_DER: OnceLock> = OnceLock::new(); + static CLIENT_SESSION_DER: OnceLock> = OnceLock::new(); let mut server = Server::builder(); + server.expected_connections_count(2); server .ctx() .set_max_proto_version(Some(SslVersion::TLS1_2)) @@ -42,7 +46,25 @@ fn new_session_callback() { server.ctx().set_options(SslOptions::NO_TICKET); server .ctx() - .set_new_session_callback(|_, session| SESSION_ID.set(session.id().to_owned()).unwrap()); + .set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL); + server.ctx().set_new_session_callback(|_, session| { + SERVER_SESSION_DER.set(session.to_der().unwrap()).unwrap() + }); + unsafe { + server.ctx().set_get_session_callback(|_, id| { + let Some(der) = SERVER_SESSION_DER.get() else { + return None; + }; + + let session = SslSession::from_der(der).unwrap(); + + FOUND_SESSION.store(true, Ordering::SeqCst); + + assert_eq!(id, session.id()); + + Some(session) + }); + } server.ctx().set_session_id_context(b"foo").unwrap(); let server = server.build(); @@ -51,14 +73,31 @@ fn new_session_callback() { client .ctx() - .set_session_cache_mode(SslSessionCacheMode::CLIENT | SslSessionCacheMode::NO_INTERNAL); - client - .ctx() - .set_new_session_callback(|_, session| assert_eq!(SESSION_ID.get().unwrap(), session.id())); + .set_session_cache_mode(SslSessionCacheMode::CLIENT); + client.ctx().set_new_session_callback(|_, session| { + CLIENT_SESSION_DER.set(session.to_der().unwrap()).unwrap() + }); - client.connect(); + let client = client.build(); - assert!(SESSION_ID.get().is_some()); + client.builder().connect(); + + assert!(CLIENT_SESSION_DER.get().is_some()); + assert!(SERVER_SESSION_DER.get().is_some()); + assert!(!FOUND_SESSION.load(Ordering::SeqCst)); + + let mut ssl_builder = client.builder(); + + unsafe { + ssl_builder + .ssl() + .set_session(&SslSession::from_der(CLIENT_SESSION_DER.get().unwrap()).unwrap()) + .unwrap(); + } + + ssl_builder.connect(); + + assert!(FOUND_SESSION.load(Ordering::SeqCst)); } #[test]