Test set_get_session_callback

This commit is contained in:
Anthony Ramine 2023-10-24 14:23:14 +02:00 committed by Alessandro Ghedini
parent 965fde7bae
commit 1e2a4812d2
2 changed files with 74 additions and 20 deletions

View File

@ -32,6 +32,7 @@ impl Server {
io_cb: Box::new(|_| {}), io_cb: Box::new(|_| {}),
err_cb: Box::new(|_| {}), err_cb: Box::new(|_| {}),
should_error: false, should_error: false,
expected_connections_count: 1,
} }
} }
@ -61,6 +62,7 @@ pub struct Builder {
io_cb: Box<dyn FnMut(SslStream<TcpStream>) + Send>, io_cb: Box<dyn FnMut(SslStream<TcpStream>) + Send>,
err_cb: Box<dyn FnMut(HandshakeError<TcpStream>) + Send>, err_cb: Box<dyn FnMut(HandshakeError<TcpStream>) + Send>,
should_error: bool, should_error: bool,
expected_connections_count: usize,
} }
impl Builder { impl Builder {
@ -92,6 +94,10 @@ impl Builder {
self.should_error = true; self.should_error = true;
} }
pub fn expected_connections_count(&mut self, count: usize) {
self.expected_connections_count = count;
}
pub fn build(self) -> Server { pub fn build(self) -> Server {
let ctx = self.ctx.build(); let ctx = self.ctx.build();
let socket = TcpListener::bind("127.0.0.1:0").unwrap(); let socket = TcpListener::bind("127.0.0.1:0").unwrap();
@ -100,19 +106,28 @@ impl Builder {
let mut io_cb = self.io_cb; let mut io_cb = self.io_cb;
let mut err_cb = self.err_cb; let mut err_cb = self.err_cb;
let should_error = self.should_error; let should_error = self.should_error;
let mut count = self.expected_connections_count;
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
while count > 0 {
let socket = socket.accept().unwrap().0; let socket = socket.accept().unwrap().0;
let mut ssl = Ssl::new(&ctx).unwrap(); let mut ssl = Ssl::new(&ctx).unwrap();
ssl_cb(&mut ssl); ssl_cb(&mut ssl);
let r = ssl.accept(socket); let r = ssl.accept(socket);
if should_error { if should_error {
err_cb(r.unwrap_err()); err_cb(r.unwrap_err());
} else { } else {
let mut socket = r.unwrap(); let mut socket = r.unwrap();
socket.write_all(&[0]).unwrap(); socket.write_all(&[0]).unwrap();
io_cb(socket); io_cb(socket);
} }
count -= 1;
}
}); });
Server { Server {

View File

@ -3,7 +3,8 @@ use std::sync::OnceLock;
use crate::ssl::test::server::Server; use crate::ssl::test::server::Server;
use crate::ssl::{ use crate::ssl::{
Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSessionCacheMode, SslVersion, Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSession, SslSessionCacheMode,
SslVersion,
}; };
#[test] #[test]
@ -30,11 +31,14 @@ fn active_session() {
} }
#[test] #[test]
fn new_session_callback() { fn new_get_session_callback() {
static SESSION_ID: OnceLock<Vec<u8>> = OnceLock::new(); static FOUND_SESSION: AtomicBool = AtomicBool::new(false);
static SERVER_SESSION_DER: OnceLock<Vec<u8>> = OnceLock::new();
static CLIENT_SESSION_DER: OnceLock<Vec<u8>> = OnceLock::new();
let mut server = Server::builder(); let mut server = Server::builder();
server.expected_connections_count(2);
server server
.ctx() .ctx()
.set_max_proto_version(Some(SslVersion::TLS1_2)) .set_max_proto_version(Some(SslVersion::TLS1_2))
@ -42,7 +46,25 @@ fn new_session_callback() {
server.ctx().set_options(SslOptions::NO_TICKET); server.ctx().set_options(SslOptions::NO_TICKET);
server server
.ctx() .ctx()
.set_new_session_callback(|_, session| SESSION_ID.set(session.id().to_owned()).unwrap()); .set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL);
server.ctx().set_new_session_callback(|_, session| {
SERVER_SESSION_DER.set(session.to_der().unwrap()).unwrap()
});
unsafe {
server.ctx().set_get_session_callback(|_, id| {
let Some(der) = SERVER_SESSION_DER.get() else {
return None;
};
let session = SslSession::from_der(der).unwrap();
FOUND_SESSION.store(true, Ordering::SeqCst);
assert_eq!(id, session.id());
Some(session)
});
}
server.ctx().set_session_id_context(b"foo").unwrap(); server.ctx().set_session_id_context(b"foo").unwrap();
let server = server.build(); let server = server.build();
@ -51,14 +73,31 @@ fn new_session_callback() {
client client
.ctx() .ctx()
.set_session_cache_mode(SslSessionCacheMode::CLIENT | SslSessionCacheMode::NO_INTERNAL); .set_session_cache_mode(SslSessionCacheMode::CLIENT);
client client.ctx().set_new_session_callback(|_, session| {
.ctx() CLIENT_SESSION_DER.set(session.to_der().unwrap()).unwrap()
.set_new_session_callback(|_, session| assert_eq!(SESSION_ID.get().unwrap(), session.id())); });
client.connect(); let client = client.build();
assert!(SESSION_ID.get().is_some()); client.builder().connect();
assert!(CLIENT_SESSION_DER.get().is_some());
assert!(SERVER_SESSION_DER.get().is_some());
assert!(!FOUND_SESSION.load(Ordering::SeqCst));
let mut ssl_builder = client.builder();
unsafe {
ssl_builder
.ssl()
.set_session(&SslSession::from_der(CLIENT_SESSION_DER.get().unwrap()).unwrap())
.unwrap();
}
ssl_builder.connect();
assert!(FOUND_SESSION.load(Ordering::SeqCst));
} }
#[test] #[test]