Implement conversion of HandshakeError to the source stream

This commit is contained in:
Ivan Nikulin 2020-12-23 12:25:30 +00:00
parent 42322f8b1e
commit f40ac2d1ed
4 changed files with 91 additions and 12 deletions

View File

@ -62,14 +62,27 @@ pub unsafe fn take_panic<S>(bio: *mut BIO) -> Option<Box<dyn Any + Send>> {
}
pub unsafe fn get_ref<'a, S: 'a>(bio: *mut BIO) -> &'a S {
let state = &*(BIO_get_data(bio) as *const StreamState<S>);
&state.stream
&state(bio).stream
}
pub unsafe fn get_mut<'a, S: 'a>(bio: *mut BIO) -> &'a mut S {
&mut state(bio).stream
}
pub unsafe extern "C" fn take_stream<S>(bio: *mut BIO) -> S {
assert!(!bio.is_null());
let data = BIO_get_data(bio);
assert!(!data.is_null());
let state = Box::<StreamState<S>>::from_raw(data as *mut _);
BIO_set_data(bio, ptr::null_mut());
state.stream
}
pub unsafe fn set_dtls_mtu_size<S>(bio: *mut BIO, mtu_size: usize) {
if mtu_size as u64 > c_long::max_value() as u64 {
panic!(
@ -81,7 +94,11 @@ pub unsafe fn set_dtls_mtu_size<S>(bio: *mut BIO, mtu_size: usize) {
}
unsafe fn state<'a, S: 'a>(bio: *mut BIO) -> &'a mut StreamState<S> {
&mut *(BIO_get_data(bio) as *mut _)
let data = BIO_get_data(bio) as *mut StreamState<S>;
assert!(!data.is_null());
&mut *data
}
unsafe extern "C" fn bwrite<S: Write>(bio: *mut BIO, buf: *const c_char, len: c_int) -> c_int {
@ -181,9 +198,12 @@ unsafe extern "C" fn destroy<S>(bio: *mut BIO) -> c_int {
}
let data = BIO_get_data(bio);
assert!(!data.is_null());
if !data.is_null() {
Box::<StreamState<S>>::from_raw(data as *mut _);
BIO_set_data(bio, ptr::null_mut());
}
BIO_set_init(bio, 0);
1
}

View File

@ -2650,6 +2650,11 @@ impl<S> MidHandshakeSslStream<S> {
self.error
}
/// Returns the source data stream.
pub fn into_source_stream(self) -> S {
self.stream.into_inner()
}
/// Restarts the handshake process.
///
/// This corresponds to [`SSL_do_handshake`].
@ -2856,6 +2861,11 @@ impl<S> SslStream<S> {
unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) }
}
/// Converts the SslStream to the underlying data stream.
pub fn into_inner(self) -> S {
unsafe { bio::take_stream::<S>(self.ssl.get_raw_rbio()) }
}
/// Returns a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
unsafe {

View File

@ -252,6 +252,14 @@ impl<S> HandshakeError<S> {
_ => None,
}
}
/// Converts error to the source data stream tha was used for the handshake.
pub fn into_source_stream(self) -> Option<S> {
match self.0 {
ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
_ => None,
}
}
}
impl<S> fmt::Debug for HandshakeError<S>

View File

@ -1,9 +1,11 @@
use boring::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod};
use futures::future;
use std::net::ToSocketAddrs;
use std::future::Future;
use std::net::{SocketAddr, ToSocketAddrs};
use std::pin::Pin;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_boring::{HandshakeError, SslStream};
#[tokio::test]
async fn google() {
@ -31,9 +33,12 @@ async fn google() {
assert!(response.ends_with("</html>") || response.ends_with("</HTML>"));
}
#[tokio::test]
async fn server() {
let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
fn create_server() -> (
impl Future<Output = Result<SslStream<TcpStream>, HandshakeError<TcpStream>>>,
SocketAddr,
) {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let mut listener = TcpListener::from_std(listener).unwrap();
let addr = listener.local_addr().unwrap();
let server = async move {
@ -47,8 +52,19 @@ async fn server() {
let acceptor = acceptor.build();
let stream = listener.accept().await.unwrap().0;
let mut stream = tokio_boring::accept(&acceptor, stream).await.unwrap();
tokio_boring::accept(&acceptor, stream).await
};
(server, addr)
}
#[tokio::test]
async fn server() {
let (stream, addr) = create_server();
let server = async {
let mut stream = stream.await.unwrap();
let mut buf = [0; 4];
stream.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"asdf");
@ -57,7 +73,7 @@ async fn server() {
future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx))
.await
.unwrap()
.unwrap();
};
let client = async {
@ -79,3 +95,28 @@ async fn server() {
future::join(server, client).await;
}
#[tokio::test]
async fn handshake_error() {
let (stream, addr) = create_server();
let server = async {
let err = stream.await.unwrap_err();
assert!(err.into_source_stream().is_some());
};
let client = async {
let connector = SslConnector::builder(SslMethod::tls()).unwrap();
let config = connector.build().configure().unwrap();
let stream = TcpStream::connect(&addr).await.unwrap();
let err = tokio_boring::connect(config, "localhost", stream)
.await
.unwrap_err();
assert!(err.into_source_stream().is_some());
};
future::join(server, client).await;
}