Test set_get_session_callback

This commit is contained in:
Anthony Ramine 2023-10-24 14:23:14 +02:00 committed by Alessandro Ghedini
parent 965fde7bae
commit 1e2a4812d2
2 changed files with 74 additions and 20 deletions

View File

@ -32,6 +32,7 @@ impl Server {
io_cb: Box::new(|_| {}),
err_cb: Box::new(|_| {}),
should_error: false,
expected_connections_count: 1,
}
}
@ -61,6 +62,7 @@ pub struct Builder {
io_cb: Box<dyn FnMut(SslStream<TcpStream>) + Send>,
err_cb: Box<dyn FnMut(HandshakeError<TcpStream>) + Send>,
should_error: bool,
expected_connections_count: usize,
}
impl Builder {
@ -92,6 +94,10 @@ impl Builder {
self.should_error = true;
}
pub fn expected_connections_count(&mut self, count: usize) {
self.expected_connections_count = count;
}
pub fn build(self) -> Server {
let ctx = self.ctx.build();
let socket = TcpListener::bind("127.0.0.1:0").unwrap();
@ -100,18 +106,27 @@ impl Builder {
let mut io_cb = self.io_cb;
let mut err_cb = self.err_cb;
let should_error = self.should_error;
let mut count = self.expected_connections_count;
let handle = thread::spawn(move || {
let socket = socket.accept().unwrap().0;
let mut ssl = Ssl::new(&ctx).unwrap();
ssl_cb(&mut ssl);
let r = ssl.accept(socket);
if should_error {
err_cb(r.unwrap_err());
} else {
let mut socket = r.unwrap();
socket.write_all(&[0]).unwrap();
io_cb(socket);
while count > 0 {
let socket = socket.accept().unwrap().0;
let mut ssl = Ssl::new(&ctx).unwrap();
ssl_cb(&mut ssl);
let r = ssl.accept(socket);
if should_error {
err_cb(r.unwrap_err());
} else {
let mut socket = r.unwrap();
socket.write_all(&[0]).unwrap();
io_cb(socket);
}
count -= 1;
}
});

View File

@ -3,7 +3,8 @@ use std::sync::OnceLock;
use crate::ssl::test::server::Server;
use crate::ssl::{
Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSessionCacheMode, SslVersion,
Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslSession, SslSessionCacheMode,
SslVersion,
};
#[test]
@ -30,11 +31,14 @@ fn active_session() {
}
#[test]
fn new_session_callback() {
static SESSION_ID: OnceLock<Vec<u8>> = OnceLock::new();
fn new_get_session_callback() {
static FOUND_SESSION: AtomicBool = AtomicBool::new(false);
static SERVER_SESSION_DER: OnceLock<Vec<u8>> = OnceLock::new();
static CLIENT_SESSION_DER: OnceLock<Vec<u8>> = OnceLock::new();
let mut server = Server::builder();
server.expected_connections_count(2);
server
.ctx()
.set_max_proto_version(Some(SslVersion::TLS1_2))
@ -42,7 +46,25 @@ fn new_session_callback() {
server.ctx().set_options(SslOptions::NO_TICKET);
server
.ctx()
.set_new_session_callback(|_, session| SESSION_ID.set(session.id().to_owned()).unwrap());
.set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL);
server.ctx().set_new_session_callback(|_, session| {
SERVER_SESSION_DER.set(session.to_der().unwrap()).unwrap()
});
unsafe {
server.ctx().set_get_session_callback(|_, id| {
let Some(der) = SERVER_SESSION_DER.get() else {
return None;
};
let session = SslSession::from_der(der).unwrap();
FOUND_SESSION.store(true, Ordering::SeqCst);
assert_eq!(id, session.id());
Some(session)
});
}
server.ctx().set_session_id_context(b"foo").unwrap();
let server = server.build();
@ -51,14 +73,31 @@ fn new_session_callback() {
client
.ctx()
.set_session_cache_mode(SslSessionCacheMode::CLIENT | SslSessionCacheMode::NO_INTERNAL);
client
.ctx()
.set_new_session_callback(|_, session| assert_eq!(SESSION_ID.get().unwrap(), session.id()));
.set_session_cache_mode(SslSessionCacheMode::CLIENT);
client.ctx().set_new_session_callback(|_, session| {
CLIENT_SESSION_DER.set(session.to_der().unwrap()).unwrap()
});
client.connect();
let client = client.build();
assert!(SESSION_ID.get().is_some());
client.builder().connect();
assert!(CLIENT_SESSION_DER.get().is_some());
assert!(SERVER_SESSION_DER.get().is_some());
assert!(!FOUND_SESSION.load(Ordering::SeqCst));
let mut ssl_builder = client.builder();
unsafe {
ssl_builder
.ssl()
.set_session(&SslSession::from_der(CLIENT_SESSION_DER.get().unwrap()).unwrap())
.unwrap();
}
ssl_builder.connect();
assert!(FOUND_SESSION.load(Ordering::SeqCst));
}
#[test]