diff --git a/boring/src/ssl/bio.rs b/boring/src/ssl/bio.rs index 1fafad66..0dbb6869 100644 --- a/boring/src/ssl/bio.rs +++ b/boring/src/ssl/bio.rs @@ -62,14 +62,27 @@ pub unsafe fn take_panic(bio: *mut BIO) -> Option> { } pub unsafe fn get_ref<'a, S: 'a>(bio: *mut BIO) -> &'a S { - let state = &*(BIO_get_data(bio) as *const StreamState); - &state.stream + &state(bio).stream } pub unsafe fn get_mut<'a, S: 'a>(bio: *mut BIO) -> &'a mut S { &mut state(bio).stream } +pub unsafe extern "C" fn take_stream(bio: *mut BIO) -> S { + assert!(!bio.is_null()); + + let data = BIO_get_data(bio); + + assert!(!data.is_null()); + + let state = Box::>::from_raw(data as *mut _); + + BIO_set_data(bio, ptr::null_mut()); + + state.stream +} + pub unsafe fn set_dtls_mtu_size(bio: *mut BIO, mtu_size: usize) { if mtu_size as u64 > c_long::max_value() as u64 { panic!( @@ -81,7 +94,11 @@ pub unsafe fn set_dtls_mtu_size(bio: *mut BIO, mtu_size: usize) { } unsafe fn state<'a, S: 'a>(bio: *mut BIO) -> &'a mut StreamState { - &mut *(BIO_get_data(bio) as *mut _) + let data = BIO_get_data(bio) as *mut StreamState; + + assert!(!data.is_null()); + + &mut *data } unsafe extern "C" fn bwrite(bio: *mut BIO, buf: *const c_char, len: c_int) -> c_int { @@ -181,9 +198,12 @@ unsafe extern "C" fn destroy(bio: *mut BIO) -> c_int { } let data = BIO_get_data(bio); - assert!(!data.is_null()); - Box::>::from_raw(data as *mut _); - BIO_set_data(bio, ptr::null_mut()); + + if !data.is_null() { + Box::>::from_raw(data as *mut _); + BIO_set_data(bio, ptr::null_mut()); + } + BIO_set_init(bio, 0); 1 } diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 442f61a6..45670613 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -2650,6 +2650,11 @@ impl MidHandshakeSslStream { self.error } + /// Returns the source data stream. + pub fn into_source_stream(self) -> S { + self.stream.into_inner() + } + /// Restarts the handshake process. /// /// This corresponds to [`SSL_do_handshake`]. @@ -2856,6 +2861,11 @@ impl SslStream { unsafe { bio::take_error::(self.ssl.get_raw_rbio()) } } + /// Converts the SslStream to the underlying data stream. + pub fn into_inner(self) -> S { + unsafe { bio::take_stream::(self.ssl.get_raw_rbio()) } + } + /// Returns a shared reference to the underlying stream. pub fn get_ref(&self) -> &S { unsafe { diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index ccbf0956..dbda8be1 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -252,6 +252,14 @@ impl HandshakeError { _ => None, } } + + /// Converts error to the source data stream tha was used for the handshake. + pub fn into_source_stream(self) -> Option { + match self.0 { + ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream), + _ => None, + } + } } impl fmt::Debug for HandshakeError diff --git a/tokio-boring/tests/google.rs b/tokio-boring/tests/google.rs index 25fca764..d408d0ec 100644 --- a/tokio-boring/tests/google.rs +++ b/tokio-boring/tests/google.rs @@ -1,9 +1,11 @@ use boring::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod}; use futures::future; -use std::net::ToSocketAddrs; +use std::future::Future; +use std::net::{SocketAddr, ToSocketAddrs}; use std::pin::Pin; use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; +use tokio_boring::{HandshakeError, SslStream}; #[tokio::test] async fn google() { @@ -31,9 +33,12 @@ async fn google() { assert!(response.ends_with("") || response.ends_with("")); } -#[tokio::test] -async fn server() { - let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); +fn create_server() -> ( + impl Future, HandshakeError>>, + SocketAddr, +) { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let mut listener = TcpListener::from_std(listener).unwrap(); let addr = listener.local_addr().unwrap(); let server = async move { @@ -47,8 +52,19 @@ async fn server() { let acceptor = acceptor.build(); let stream = listener.accept().await.unwrap().0; - let mut stream = tokio_boring::accept(&acceptor, stream).await.unwrap(); + tokio_boring::accept(&acceptor, stream).await + }; + + (server, addr) +} + +#[tokio::test] +async fn server() { + let (stream, addr) = create_server(); + + let server = async { + let mut stream = stream.await.unwrap(); let mut buf = [0; 4]; stream.read_exact(&mut buf).await.unwrap(); assert_eq!(&buf, b"asdf"); @@ -57,7 +73,7 @@ async fn server() { future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx)) .await - .unwrap() + .unwrap(); }; let client = async { @@ -79,3 +95,28 @@ async fn server() { future::join(server, client).await; } + +#[tokio::test] +async fn handshake_error() { + let (stream, addr) = create_server(); + + let server = async { + let err = stream.await.unwrap_err(); + + assert!(err.into_source_stream().is_some()); + }; + + let client = async { + let connector = SslConnector::builder(SslMethod::tls()).unwrap(); + let config = connector.build().configure().unwrap(); + let stream = TcpStream::connect(&addr).await.unwrap(); + + let err = tokio_boring::connect(config, "localhost", stream) + .await + .unwrap_err(); + + assert!(err.into_source_stream().is_some()); + }; + + future::join(server, client).await; +}