diff --git a/openssl-sys/src/lib.rs b/openssl-sys/src/lib.rs index f19fa16c..03824e05 100644 --- a/openssl-sys/src/lib.rs +++ b/openssl-sys/src/lib.rs @@ -60,6 +60,9 @@ pub struct BIO_METHOD { pub callback_ctrl: Option c_long>, } +// so we can create static BIO_METHODs +unsafe impl Sync for BIO_METHOD {} + #[repr(C)] pub struct BIO { pub method: *mut BIO_METHOD, diff --git a/openssl/src/ssl/bio.rs b/openssl/src/ssl/bio.rs index 92479b37..1009c0bc 100644 --- a/openssl/src/ssl/bio.rs +++ b/openssl/src/ssl/bio.rs @@ -1,25 +1,38 @@ use libc::{c_char, c_int, c_long, c_void, strlen}; use ffi::{BIO, BIO_METHOD, BIO_CTRL_FLUSH, BIO_TYPE_NONE, BIO_new}; use ffi_extras::{BIO_clear_retry_flags, BIO_set_retry_read, BIO_set_retry_write}; -use std::any::Any; use std::io; use std::io::prelude::*; -use std::marker::PhantomData; use std::mem; use std::slice; -use std::sync::Mutex; use std::ptr; use ssl::error::SslError; +// "rust" +const NAME: [c_char; 5] = [114, 117, 115, 116, 0]; + +// we use this after removing the stream from the BIO so that we don't have to +// worry about freeing the heap allocated BIO_METHOD after freeing the BIO. +static DESTROY_METHOD: BIO_METHOD = BIO_METHOD { + type_: BIO_TYPE_NONE, + name: &NAME[0], + bwrite: None, + bread: None, + bputs: None, + bgets: None, + ctrl: None, + create: None, + destroy: Some(destroy), + callback_ctrl: None, +}; + pub struct StreamState { pub stream: S, pub error: Option, } -pub fn new_bio(stream: S) -> Result<(*mut BIO, Box), SslError> { - // "rust" - static NAME: [c_char; 5] = [114, 117, 115, 116, 0]; +pub fn new(stream: S) -> Result<(*mut BIO, Box), SslError> { let method = Box::new(BIO_METHOD { type_: BIO_TYPE_NONE, @@ -30,7 +43,7 @@ pub fn new_bio(stream: S) -> Result<(*mut BIO, Box) bgets: None, ctrl: Some(ctrl::), create: Some(create), - destroy: Some(destroy), + destroy: None, // covered in the replacement BIO_METHOD callback_ctrl: None, }); @@ -40,12 +53,9 @@ pub fn new_bio(stream: S) -> Result<(*mut BIO, Box) }); unsafe { - let bio = BIO_new(&*method); - if bio.is_null() { - return Err(SslError::get()); - } - + let bio = try_ssl_null!(BIO_new(&*method)); (*bio).ptr = Box::into_raw(state) as *mut _; + (*bio).init = 1; return Ok((bio, method)); } @@ -59,9 +69,20 @@ pub unsafe fn take_error(bio: *mut BIO) -> Option { pub unsafe fn take_stream(bio: *mut BIO) -> S { let state: Box> = Box::from_raw((*bio).ptr as *mut _); (*bio).ptr = ptr::null_mut(); + (*bio).method = &DESTROY_METHOD as *const _ as *mut _; + (*bio).init = 0; state.stream } +pub unsafe fn get_ref<'a, S: 'a>(bio: *mut BIO) -> &'a S { + let state: &'a StreamState = mem::transmute((*bio).ptr); + &state.stream +} + +pub unsafe fn get_mut<'a, S: 'a>(bio: *mut BIO) -> &'a mut S { + &mut state(bio).stream +} + unsafe fn state<'a, S: 'a>(bio: *mut BIO) -> &'a mut StreamState { mem::transmute((*bio).ptr) } @@ -106,8 +127,8 @@ unsafe extern "C" fn bputs(bio: *mut BIO, s: *const c_char) -> c_int { unsafe extern "C" fn ctrl(bio: *mut BIO, cmd: c_int, - num: c_long, - ptr: *mut c_void) + _num: c_long, + _ptr: *mut c_void) -> c_long { if cmd == BIO_CTRL_FLUSH { let state = state::(bio); diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 2f09e4fa..5d24fc32 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -18,6 +18,7 @@ use std::any::Any; use libc::{c_uchar, c_uint}; #[cfg(any(feature = "npn", feature = "alpn"))] use std::slice; +use std::marker::PhantomData; use bio::{MemBio}; use ffi; @@ -724,6 +725,10 @@ impl Ssl { Ok(ssl) } + fn get_raw_rbio(&self) -> *mut ffi::BIO { + unsafe { ffi::SSL_get_rbio(self.ssl) } + } + fn get_rbio<'a>(&'a self) -> MemBioRef<'a> { unsafe { self.wrap_bio(ffi::SSL_get_rbio(self.ssl)) } } @@ -1345,6 +1350,102 @@ impl Write for SslStream { } } +pub struct SslStreamNg { + ssl: Ssl, + _method: Box, // :( + _p: PhantomData, +} + +impl Drop for SslStreamNg { + fn drop(&mut self) { + unsafe { + let _ = bio::take_stream::(self.ssl.get_raw_rbio()); + } + } +} + +impl SslStreamNg { + fn new_base(ssl: Ssl, stream: S) -> Result { + unsafe { + let (bio, method) = try!(bio::new(stream)); + ffi::SSL_set_bio(ssl.ssl, bio, bio); + + Ok(SslStreamNg { + ssl: ssl, + _method: method, + _p: PhantomData, + }) + } + } + + /// Creates an SSL/TLS client operating over the provided stream. + pub fn connect(ssl: T, stream: S) -> Result { + let ssl = try!(ssl.into_ssl()); + let mut stream = try!(Self::new_base(ssl, stream)); + let ret = stream.ssl.connect(); + if ret > 0 { + Ok(stream) + } else { + Err(stream.make_error(ret)) + } + } + + /// Creates an SSL/TLS server operating over the provided stream. + pub fn accept(ssl: T, stream: S) -> Result { + let ssl = try!(ssl.into_ssl()); + let mut stream = try!(Self::new_base(ssl, stream)); + let ret = stream.ssl.accept(); + if ret > 0 { + Ok(stream) + } else { + Err(stream.make_error(ret)) + } + } + + pub fn get_ref(&self) -> &S { + unsafe { + let bio = self.ssl.get_raw_rbio(); + bio::get_ref(bio) + } + } + + pub fn mut_ref(&mut self) -> &mut S { + unsafe { + let bio = self.ssl.get_raw_rbio(); + bio::get_mut(bio) + } + } + + fn make_error(&mut self, ret: c_int) -> SslError { + match self.ssl.get_error(ret) { + LibSslError::ErrorSsl => SslError::get(), + LibSslError::ErrorSyscall => { + let err = SslError::get(); + let count = match err { + SslError::OpenSslErrors(ref v) => v.len(), + _ => unreachable!(), + }; + if count == 0 { + if ret == 0 { + SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed")) + } else { + let error = unsafe { bio::take_error::(self.ssl.get_raw_rbio()) }; + SslError::StreamError(error.unwrap()) + } + } else { + err + } + } + LibSslError::ErrorWantWrite | LibSslError::ErrorWantRead => { + let error = unsafe { bio::take_error::(self.ssl.get_raw_rbio()) }; + SslError::StreamError(error.unwrap()) + } + err => panic!("unexpected error {:?} with ret {}", err, ret), + } + } +} + pub trait IntoSsl { fn into_ssl(self) -> Result; } diff --git a/openssl/src/ssl/tests/mod.rs b/openssl/src/ssl/tests/mod.rs index 025a45a8..d1f34019 100644 --- a/openssl/src/ssl/tests/mod.rs +++ b/openssl/src/ssl/tests/mod.rs @@ -18,7 +18,7 @@ use ssl::SSL_VERIFY_PEER; use ssl::SslMethod::Sslv23; use ssl::SslMethod; use ssl::error::NonblockingSslError; -use ssl::{SslContext, SslStream, VerifyCallback, NonblockingSslStream}; +use ssl::{SslContext, SslStream, VerifyCallback, NonblockingSslStream, SslStreamNg}; use x509::X509StoreContext; use x509::X509FileType; use x509::X509; @@ -929,3 +929,10 @@ fn test_read_nonblocking() { assert!(bytes_read >= 5); assert_eq!(&input_buffer[..5], b"HTTP/"); } + +#[test] +fn ng_connect() { + let (_s, stream) = Server::new(); + let ctx = SslContext::new(Sslv23).unwrap(); + SslStreamNg::connect(&ctx, stream).unwrap(); +}