Implement conversion of HandshakeError to the source stream
This commit is contained in:
parent
42322f8b1e
commit
f40ac2d1ed
|
|
@ -62,14 +62,27 @@ pub unsafe fn take_panic<S>(bio: *mut BIO) -> Option<Box<dyn Any + Send>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn get_ref<'a, S: 'a>(bio: *mut BIO) -> &'a S {
|
pub unsafe fn get_ref<'a, S: 'a>(bio: *mut BIO) -> &'a S {
|
||||||
let state = &*(BIO_get_data(bio) as *const StreamState<S>);
|
&state(bio).stream
|
||||||
&state.stream
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn get_mut<'a, S: 'a>(bio: *mut BIO) -> &'a mut S {
|
pub unsafe fn get_mut<'a, S: 'a>(bio: *mut BIO) -> &'a mut S {
|
||||||
&mut state(bio).stream
|
&mut state(bio).stream
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub unsafe extern "C" fn take_stream<S>(bio: *mut BIO) -> S {
|
||||||
|
assert!(!bio.is_null());
|
||||||
|
|
||||||
|
let data = BIO_get_data(bio);
|
||||||
|
|
||||||
|
assert!(!data.is_null());
|
||||||
|
|
||||||
|
let state = Box::<StreamState<S>>::from_raw(data as *mut _);
|
||||||
|
|
||||||
|
BIO_set_data(bio, ptr::null_mut());
|
||||||
|
|
||||||
|
state.stream
|
||||||
|
}
|
||||||
|
|
||||||
pub unsafe fn set_dtls_mtu_size<S>(bio: *mut BIO, mtu_size: usize) {
|
pub unsafe fn set_dtls_mtu_size<S>(bio: *mut BIO, mtu_size: usize) {
|
||||||
if mtu_size as u64 > c_long::max_value() as u64 {
|
if mtu_size as u64 > c_long::max_value() as u64 {
|
||||||
panic!(
|
panic!(
|
||||||
|
|
@ -81,7 +94,11 @@ pub unsafe fn set_dtls_mtu_size<S>(bio: *mut BIO, mtu_size: usize) {
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn state<'a, S: 'a>(bio: *mut BIO) -> &'a mut StreamState<S> {
|
unsafe fn state<'a, S: 'a>(bio: *mut BIO) -> &'a mut StreamState<S> {
|
||||||
&mut *(BIO_get_data(bio) as *mut _)
|
let data = BIO_get_data(bio) as *mut StreamState<S>;
|
||||||
|
|
||||||
|
assert!(!data.is_null());
|
||||||
|
|
||||||
|
&mut *data
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "C" fn bwrite<S: Write>(bio: *mut BIO, buf: *const c_char, len: c_int) -> c_int {
|
unsafe extern "C" fn bwrite<S: Write>(bio: *mut BIO, buf: *const c_char, len: c_int) -> c_int {
|
||||||
|
|
@ -181,9 +198,12 @@ unsafe extern "C" fn destroy<S>(bio: *mut BIO) -> c_int {
|
||||||
}
|
}
|
||||||
|
|
||||||
let data = BIO_get_data(bio);
|
let data = BIO_get_data(bio);
|
||||||
assert!(!data.is_null());
|
|
||||||
|
if !data.is_null() {
|
||||||
Box::<StreamState<S>>::from_raw(data as *mut _);
|
Box::<StreamState<S>>::from_raw(data as *mut _);
|
||||||
BIO_set_data(bio, ptr::null_mut());
|
BIO_set_data(bio, ptr::null_mut());
|
||||||
|
}
|
||||||
|
|
||||||
BIO_set_init(bio, 0);
|
BIO_set_init(bio, 0);
|
||||||
1
|
1
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2650,6 +2650,11 @@ impl<S> MidHandshakeSslStream<S> {
|
||||||
self.error
|
self.error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the source data stream.
|
||||||
|
pub fn into_source_stream(self) -> S {
|
||||||
|
self.stream.into_inner()
|
||||||
|
}
|
||||||
|
|
||||||
/// Restarts the handshake process.
|
/// Restarts the handshake process.
|
||||||
///
|
///
|
||||||
/// This corresponds to [`SSL_do_handshake`].
|
/// This corresponds to [`SSL_do_handshake`].
|
||||||
|
|
@ -2856,6 +2861,11 @@ impl<S> SslStream<S> {
|
||||||
unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) }
|
unsafe { bio::take_error::<S>(self.ssl.get_raw_rbio()) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts the SslStream to the underlying data stream.
|
||||||
|
pub fn into_inner(self) -> S {
|
||||||
|
unsafe { bio::take_stream::<S>(self.ssl.get_raw_rbio()) }
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a shared reference to the underlying stream.
|
/// Returns a shared reference to the underlying stream.
|
||||||
pub fn get_ref(&self) -> &S {
|
pub fn get_ref(&self) -> &S {
|
||||||
unsafe {
|
unsafe {
|
||||||
|
|
|
||||||
|
|
@ -252,6 +252,14 @@ impl<S> HandshakeError<S> {
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts error to the source data stream tha was used for the handshake.
|
||||||
|
pub fn into_source_stream(self) -> Option<S> {
|
||||||
|
match self.0 {
|
||||||
|
ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> fmt::Debug for HandshakeError<S>
|
impl<S> fmt::Debug for HandshakeError<S>
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
use boring::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod};
|
use boring::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod};
|
||||||
use futures::future;
|
use futures::future;
|
||||||
use std::net::ToSocketAddrs;
|
use std::future::Future;
|
||||||
|
use std::net::{SocketAddr, ToSocketAddrs};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
use tokio_boring::{HandshakeError, SslStream};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn google() {
|
async fn google() {
|
||||||
|
|
@ -31,9 +33,12 @@ async fn google() {
|
||||||
assert!(response.ends_with("</html>") || response.ends_with("</HTML>"));
|
assert!(response.ends_with("</html>") || response.ends_with("</HTML>"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
fn create_server() -> (
|
||||||
async fn server() {
|
impl Future<Output = Result<SslStream<TcpStream>, HandshakeError<TcpStream>>>,
|
||||||
let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
SocketAddr,
|
||||||
|
) {
|
||||||
|
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
|
||||||
|
let mut listener = TcpListener::from_std(listener).unwrap();
|
||||||
let addr = listener.local_addr().unwrap();
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
let server = async move {
|
let server = async move {
|
||||||
|
|
@ -47,8 +52,19 @@ async fn server() {
|
||||||
let acceptor = acceptor.build();
|
let acceptor = acceptor.build();
|
||||||
|
|
||||||
let stream = listener.accept().await.unwrap().0;
|
let stream = listener.accept().await.unwrap().0;
|
||||||
let mut stream = tokio_boring::accept(&acceptor, stream).await.unwrap();
|
|
||||||
|
|
||||||
|
tokio_boring::accept(&acceptor, stream).await
|
||||||
|
};
|
||||||
|
|
||||||
|
(server, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn server() {
|
||||||
|
let (stream, addr) = create_server();
|
||||||
|
|
||||||
|
let server = async {
|
||||||
|
let mut stream = stream.await.unwrap();
|
||||||
let mut buf = [0; 4];
|
let mut buf = [0; 4];
|
||||||
stream.read_exact(&mut buf).await.unwrap();
|
stream.read_exact(&mut buf).await.unwrap();
|
||||||
assert_eq!(&buf, b"asdf");
|
assert_eq!(&buf, b"asdf");
|
||||||
|
|
@ -57,7 +73,7 @@ async fn server() {
|
||||||
|
|
||||||
future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx))
|
future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx))
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap();
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = async {
|
let client = async {
|
||||||
|
|
@ -79,3 +95,28 @@ async fn server() {
|
||||||
|
|
||||||
future::join(server, client).await;
|
future::join(server, client).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn handshake_error() {
|
||||||
|
let (stream, addr) = create_server();
|
||||||
|
|
||||||
|
let server = async {
|
||||||
|
let err = stream.await.unwrap_err();
|
||||||
|
|
||||||
|
assert!(err.into_source_stream().is_some());
|
||||||
|
};
|
||||||
|
|
||||||
|
let client = async {
|
||||||
|
let connector = SslConnector::builder(SslMethod::tls()).unwrap();
|
||||||
|
let config = connector.build().configure().unwrap();
|
||||||
|
let stream = TcpStream::connect(&addr).await.unwrap();
|
||||||
|
|
||||||
|
let err = tokio_boring::connect(config, "localhost", stream)
|
||||||
|
.await
|
||||||
|
.unwrap_err();
|
||||||
|
|
||||||
|
assert!(err.into_source_stream().is_some());
|
||||||
|
};
|
||||||
|
|
||||||
|
future::join(server, client).await;
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue