Merge pull request #5 from cloudflare/err-src-stream

Implement conversion of HandshakeError to the source stream
This commit is contained in:
Ivan Nikulin 2020-12-23 14:11:43 +00:00 committed by GitHub
commit f809be1a90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 91 additions and 12 deletions

View File

@ -62,14 +62,27 @@ pub unsafe fn take_panic<S>(bio: *mut BIO) -> Option<Box<dyn Any + Send>> {
} }
pub unsafe fn get_ref<'a, S: 'a>(bio: *mut BIO) -> &'a S { pub unsafe fn get_ref<'a, S: 'a>(bio: *mut BIO) -> &'a S {
let state = &*(BIO_get_data(bio) as *const StreamState<S>); &state(bio).stream
&state.stream
} }
pub unsafe fn get_mut<'a, S: 'a>(bio: *mut BIO) -> &'a mut S { pub unsafe fn get_mut<'a, S: 'a>(bio: *mut BIO) -> &'a mut S {
&mut state(bio).stream &mut state(bio).stream
} }
pub unsafe extern "C" fn take_stream<S>(bio: *mut BIO) -> S {
assert!(!bio.is_null());
let data = BIO_get_data(bio);
assert!(!data.is_null());
let state = Box::<StreamState<S>>::from_raw(data as *mut _);
BIO_set_data(bio, ptr::null_mut());
state.stream
}
pub unsafe fn set_dtls_mtu_size<S>(bio: *mut BIO, mtu_size: usize) { pub unsafe fn set_dtls_mtu_size<S>(bio: *mut BIO, mtu_size: usize) {
if mtu_size as u64 > c_long::max_value() as u64 { if mtu_size as u64 > c_long::max_value() as u64 {
panic!( panic!(
@ -81,7 +94,11 @@ pub unsafe fn set_dtls_mtu_size<S>(bio: *mut BIO, mtu_size: usize) {
} }
unsafe fn state<'a, S: 'a>(bio: *mut BIO) -> &'a mut StreamState<S> { unsafe fn state<'a, S: 'a>(bio: *mut BIO) -> &'a mut StreamState<S> {
&mut *(BIO_get_data(bio) as *mut _) let data = BIO_get_data(bio) as *mut StreamState<S>;
assert!(!data.is_null());
&mut *data
} }
unsafe extern "C" fn bwrite<S: Write>(bio: *mut BIO, buf: *const c_char, len: c_int) -> c_int { unsafe extern "C" fn bwrite<S: Write>(bio: *mut BIO, buf: *const c_char, len: c_int) -> c_int {
@ -181,9 +198,12 @@ unsafe extern "C" fn destroy<S>(bio: *mut BIO) -> c_int {
} }
let data = BIO_get_data(bio); let data = BIO_get_data(bio);
assert!(!data.is_null());
if !data.is_null() {
Box::<StreamState<S>>::from_raw(data as *mut _); Box::<StreamState<S>>::from_raw(data as *mut _);
BIO_set_data(bio, ptr::null_mut()); BIO_set_data(bio, ptr::null_mut());
}
BIO_set_init(bio, 0); BIO_set_init(bio, 0);
1 1
} }

View File

@ -2650,6 +2650,11 @@ impl<S> MidHandshakeSslStream<S> {
self.error self.error
} }
/// Returns the source data stream.
pub fn into_source_stream(self) -> S {
self.stream.into_inner()
}
/// Restarts the handshake process. /// Restarts the handshake process.
/// ///
/// This corresponds to [`SSL_do_handshake`]. /// This corresponds to [`SSL_do_handshake`].
@ -2856,6 +2861,11 @@ impl<S> SslStream<S> {
unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) } unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) }
} }
/// Converts the SslStream to the underlying data stream.
pub fn into_inner(self) -> S {
unsafe { bio::take_stream::<S>(self.ssl.get_raw_rbio()) }
}
/// Returns a shared reference to the underlying stream. /// Returns a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S { pub fn get_ref(&self) -> &S {
unsafe { unsafe {

View File

@ -252,6 +252,14 @@ impl<S> HandshakeError<S> {
_ => None, _ => None,
} }
} }
/// Converts error to the source data stream tha was used for the handshake.
pub fn into_source_stream(self) -> Option<S> {
match self.0 {
ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
_ => None,
}
}
} }
impl<S> fmt::Debug for HandshakeError<S> impl<S> fmt::Debug for HandshakeError<S>

View File

@ -1,9 +1,11 @@
use boring::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod}; use boring::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod};
use futures::future; use futures::future;
use std::net::ToSocketAddrs; use std::future::Future;
use std::net::{SocketAddr, ToSocketAddrs};
use std::pin::Pin; use std::pin::Pin;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio_boring::{HandshakeError, SslStream};
#[tokio::test] #[tokio::test]
async fn google() { async fn google() {
@ -31,9 +33,12 @@ async fn google() {
assert!(response.ends_with("</html>") || response.ends_with("</HTML>")); assert!(response.ends_with("</html>") || response.ends_with("</HTML>"));
} }
#[tokio::test] fn create_server() -> (
async fn server() { impl Future<Output = Result<SslStream<TcpStream>, HandshakeError<TcpStream>>>,
let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); SocketAddr,
) {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let mut listener = TcpListener::from_std(listener).unwrap();
let addr = listener.local_addr().unwrap(); let addr = listener.local_addr().unwrap();
let server = async move { let server = async move {
@ -47,8 +52,19 @@ async fn server() {
let acceptor = acceptor.build(); let acceptor = acceptor.build();
let stream = listener.accept().await.unwrap().0; let stream = listener.accept().await.unwrap().0;
let mut stream = tokio_boring::accept(&acceptor, stream).await.unwrap();
tokio_boring::accept(&acceptor, stream).await
};
(server, addr)
}
#[tokio::test]
async fn server() {
let (stream, addr) = create_server();
let server = async {
let mut stream = stream.await.unwrap();
let mut buf = [0; 4]; let mut buf = [0; 4];
stream.read_exact(&mut buf).await.unwrap(); stream.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"asdf"); assert_eq!(&buf, b"asdf");
@ -57,7 +73,7 @@ async fn server() {
future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx)) future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx))
.await .await
.unwrap() .unwrap();
}; };
let client = async { let client = async {
@ -79,3 +95,28 @@ async fn server() {
future::join(server, client).await; future::join(server, client).await;
} }
#[tokio::test]
async fn handshake_error() {
let (stream, addr) = create_server();
let server = async {
let err = stream.await.unwrap_err();
assert!(err.into_source_stream().is_some());
};
let client = async {
let connector = SslConnector::builder(SslMethod::tls()).unwrap();
let config = connector.build().configure().unwrap();
let stream = TcpStream::connect(&addr).await.unwrap();
let err = tokio_boring::connect(config, "localhost", stream)
.await
.unwrap_err();
assert!(err.into_source_stream().is_some());
};
future::join(server, client).await;
}