From 0a2013a6d561b8cf8dd26da08df1d8e83a8bf2a5 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Fri, 4 Aug 2023 13:42:09 +0200 Subject: [PATCH] Introduce helper module in tokio-boring tests --- tokio-boring/tests/client_server.rs | 88 ++++---------------------- tokio-boring/tests/common/mod.rs | 96 +++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 77 deletions(-) create mode 100644 tokio-boring/tests/common/mod.rs diff --git a/tokio-boring/tests/client_server.rs b/tokio-boring/tests/client_server.rs index 72c5a040..925f9875 100644 --- a/tokio-boring/tests/client_server.rs +++ b/tokio-boring/tests/client_server.rs @@ -1,11 +1,12 @@ -use boring::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod}; +use boring::ssl::{SslConnector, SslMethod}; use futures::future; -use std::future::Future; -use std::net::{SocketAddr, ToSocketAddrs}; -use std::pin::Pin; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream}; -use tokio_boring::{HandshakeError, SslStream}; +use std::net::ToSocketAddrs; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +mod common; + +use self::common::{connect, create_server, with_trivial_client_server_exchange}; #[tokio::test] async fn google() { @@ -33,75 +34,14 @@ async fn google() { assert!(response.ends_with("") || response.ends_with("")); } -fn create_server() -> ( - impl Future, HandshakeError>>, - SocketAddr, -) { - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - - listener.set_nonblocking(true).unwrap(); - - let listener = TcpListener::from_std(listener).unwrap(); - let addr = listener.local_addr().unwrap(); - - let server = async move { - let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - acceptor - .set_private_key_file("tests/key.pem", SslFiletype::PEM) - .unwrap(); - acceptor - .set_certificate_chain_file("tests/cert.pem") - .unwrap(); - let acceptor = acceptor.build(); - - let stream = listener.accept().await.unwrap().0; - - tokio_boring::accept(&acceptor, stream).await - }; - - (server, addr) -} - #[tokio::test] async fn server() { - let (stream, addr) = create_server(); - - let server = async { - let mut stream = stream.await.unwrap(); - let mut buf = [0; 4]; - stream.read_exact(&mut buf).await.unwrap(); - assert_eq!(&buf, b"asdf"); - - stream.write_all(b"jkl;").await.unwrap(); - - future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx)) - .await - .unwrap(); - }; - - let client = async { - let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); - connector.set_ca_file("tests/cert.pem").unwrap(); - let config = connector.build().configure().unwrap(); - - let stream = TcpStream::connect(&addr).await.unwrap(); - let mut stream = tokio_boring::connect(config, "localhost", stream) - .await - .unwrap(); - - stream.write_all(b"asdf").await.unwrap(); - - let mut buf = vec![]; - stream.read_to_end(&mut buf).await.unwrap(); - assert_eq!(buf, b"jkl;"); - }; - - future::join(server, client).await; + with_trivial_client_server_exchange(|_| ()).await; } #[tokio::test] async fn handshake_error() { - let (stream, addr) = create_server(); + let (stream, addr) = create_server(|_| ()); let server = async { let err = stream.await.unwrap_err(); @@ -110,13 +50,7 @@ async fn handshake_error() { }; let client = async { - let connector = SslConnector::builder(SslMethod::tls()).unwrap(); - let config = connector.build().configure().unwrap(); - let stream = TcpStream::connect(&addr).await.unwrap(); - - let err = tokio_boring::connect(config, "localhost", stream) - .await - .unwrap_err(); + let err = connect(addr, |_| Ok(())).await.unwrap_err(); assert!(err.into_source_stream().is_some()); }; diff --git a/tokio-boring/tests/common/mod.rs b/tokio-boring/tests/common/mod.rs new file mode 100644 index 00000000..6ed394ef --- /dev/null +++ b/tokio-boring/tests/common/mod.rs @@ -0,0 +1,96 @@ +#![allow(dead_code)] + +use boring::error::ErrorStack; +use boring::ssl::{ + SslAcceptor, SslAcceptorBuilder, SslConnector, SslConnectorBuilder, SslFiletype, SslMethod, +}; +use futures::future::{self, Future}; +use std::net::SocketAddr; +use std::pin::Pin; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_boring::{HandshakeError, SslStream}; + +pub(crate) fn create_server( + setup: impl FnOnce(&mut SslAcceptorBuilder), +) -> ( + impl Future, HandshakeError>>, + SocketAddr, +) { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + + listener.set_nonblocking(true).unwrap(); + + let listener = TcpListener::from_std(listener).unwrap(); + let addr = listener.local_addr().unwrap(); + + let server = async move { + let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + + acceptor + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); + + acceptor + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); + + setup(&mut acceptor); + + let acceptor = acceptor.build(); + + let stream = listener.accept().await.unwrap().0; + + tokio_boring::accept(&acceptor, stream).await + }; + + (server, addr) +} + +pub(crate) async fn connect( + addr: SocketAddr, + setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>, +) -> Result, HandshakeError> { + let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); + + setup(&mut connector).unwrap(); + + let config = connector.build().configure().unwrap(); + + let stream = TcpStream::connect(&addr).await.unwrap(); + + tokio_boring::connect(config, "localhost", stream).await +} + +pub(crate) async fn with_trivial_client_server_exchange( + server_setup: impl FnOnce(&mut SslAcceptorBuilder), +) { + let (stream, addr) = create_server(server_setup); + + let server = async { + let mut stream = stream.await.unwrap(); + let mut buf = [0; 4]; + stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"asdf"); + + stream.write_all(b"jkl;").await.unwrap(); + + future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx)) + .await + .unwrap(); + }; + + let client = async { + let mut stream = connect(addr, |builder| builder.set_ca_file("tests/cert.pem")) + .await + .unwrap(); + + stream.write_all(b"asdf").await.unwrap(); + + let mut buf = vec![]; + stream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"jkl;"); + }; + + future::join(server, client).await; +}