From d6ce9afdf31faacaf435380feffcd13bf387255a Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 12 Dec 2015 15:46:17 -0800 Subject: [PATCH] Have NonblockingSslStream delegate to SslStream --- openssl/src/ssl/error.rs | 29 +++-- openssl/src/ssl/mod.rs | 249 +++++++++++++-------------------------- 2 files changed, 102 insertions(+), 176 deletions(-) diff --git a/openssl/src/ssl/error.rs b/openssl/src/ssl/error.rs index 9a1a63b2..52ea6693 100644 --- a/openssl/src/ssl/error.rs +++ b/openssl/src/ssl/error.rs @@ -95,6 +95,11 @@ impl OpenSslError { errs } + /// Returns the raw OpenSSL error code for this error. + pub fn error_code(&self) -> c_ulong { + self.0 + } + /// Returns the name of the library reporting the error. pub fn library(&self) -> &'static str { get_lib(self.0) @@ -239,6 +244,17 @@ pub enum OpensslError { } } +impl OpensslError { + pub fn from_error_code(err: c_ulong) -> OpensslError { + ffi::init(); + UnknownError { + library: get_lib(err).to_owned(), + function: get_func(err).to_owned(), + reason: get_reason(err).to_owned() + } + } +} + fn get_lib(err: c_ulong) -> &'static str { unsafe { let cstr = ffi::ERR_lib_error_string(err); @@ -271,7 +287,7 @@ impl SslError { loop { match unsafe { ffi::ERR_get_error() } { 0 => break, - err => errs.push(SslError::from_error_code(err)) + err => errs.push(OpensslError::from_error_code(err)) } } OpenSslErrors(errs) @@ -279,16 +295,7 @@ impl SslError { /// Creates an `SslError` from the raw numeric error code. pub fn from_error(err: c_ulong) -> SslError { - OpenSslErrors(vec![SslError::from_error_code(err)]) - } - - fn from_error_code(err: c_ulong) -> OpensslError { - ffi::init(); - UnknownError { - library: get_lib(err).to_owned(), - function: get_func(err).to_owned(), - reason: get_reason(err).to_owned() - } + OpenSslErrors(vec![OpensslError::from_error_code(err)]) } } diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 89c8bbfc..0ffa1120 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -10,7 +10,7 @@ use std::str; use std::net; use std::path::Path; use std::ptr; -use std::sync::{Once, ONCE_INIT, Arc, Mutex}; +use std::sync::{Once, ONCE_INIT, Mutex}; use std::cmp; use std::any::Any; #[cfg(any(feature = "npn", feature = "alpn"))] @@ -18,11 +18,16 @@ use libc::{c_uchar, c_uint}; #[cfg(any(feature = "npn", feature = "alpn"))] use std::slice; use std::marker::PhantomData; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; use ffi; use ffi_extras; use dh::DH; -use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError}; +use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError, + OpensslError}; use x509::{X509StoreContext, X509FileType, X509}; use crypto::pkey::PKey; @@ -935,6 +940,20 @@ impl fmt::Debug for SslStream where S: fmt::Debug { } } +#[cfg(unix)] +impl AsRawFd for SslStream { + fn as_raw_fd(&self) -> RawFd { + self.get_ref().as_raw_fd() + } +} + +#[cfg(windows)] +impl AsRawSocket for NonblockingSslStream { + fn as_raw_fd(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + impl SslStream { fn new_base(ssl: Ssl, stream: S) -> Self { unsafe { @@ -1247,65 +1266,38 @@ impl MaybeSslStream { /// # Deprecated /// /// Use `SslStream` with `ssl_read` and `ssl_write`. -#[derive(Clone)] -pub struct NonblockingSslStream { - stream: S, - ssl: Arc, +pub struct NonblockingSslStream(SslStream); + +impl Clone for NonblockingSslStream { + fn clone(&self) -> Self { + NonblockingSslStream(self.0.clone()) + } +} + +#[cfg(unix)] +impl AsRawFd for NonblockingSslStream { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() + } +} + +#[cfg(windows)] +impl AsRawSocket for NonblockingSslStream { + fn as_raw_fd(&self) -> RawSocket { + self.0.as_raw_socket() + } } impl NonblockingSslStream { pub fn try_clone(&self) -> io::Result> { - Ok(NonblockingSslStream { - stream: try!(self.stream.try_clone()), - ssl: self.ssl.clone(), - }) + self.0.try_clone().map(NonblockingSslStream) } } impl NonblockingSslStream { - fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result, SslError> { - unsafe { - let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0)); - ffi_extras::BIO_set_nbio(bio, 1); - ffi::SSL_set_bio(ssl.ssl, bio, bio); - } - - Ok(NonblockingSslStream { - stream: stream, - ssl: Arc::new(ssl), - }) - } - - fn make_error(&self, ret: c_int) -> NonblockingSslError { - match self.ssl.get_error(ret) { - LibSslError::ErrorSsl => NonblockingSslError::SslError(SslError::get()), - LibSslError::ErrorSyscall => { - let err = SslError::get(); - let count = match err { - SslError::OpenSslErrors(ref v) => v.len(), - _ => unreachable!(), - }; - let ssl_error = if count == 0 { - if ret == 0 { - SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, - "unexpected EOF observed")) - } else { - SslError::StreamError(io::Error::last_os_error()) - } - } else { - err - }; - ssl_error.into() - }, - LibSslError::ErrorWantWrite => NonblockingSslError::WantWrite, - LibSslError::ErrorWantRead => NonblockingSslError::WantRead, - err => panic!("unexpected error {:?} with ret {}", err, ret), - } - } - /// Returns a reference to the underlying stream. pub fn get_ref(&self) -> &S { - &self.stream + self.0.get_ref() } /// Returns a mutable reference to the underlying stream. @@ -1315,117 +1307,50 @@ impl NonblockingSslStream { /// It is inadvisable to read from or write to the underlying stream as it /// will most likely corrupt the SSL session. pub fn get_mut(&mut self) -> &mut S { - &mut self.stream + self.0.get_mut() } /// Returns a reference to the Ssl. pub fn ssl(&self) -> &Ssl { - &self.ssl - } -} - -#[cfg(unix)] -impl NonblockingSslStream { - /// Create a new nonblocking client ssl connection on wrapped `stream`. - /// - /// Note that this method will most likely not actually complete the SSL - /// handshake because doing so requires several round trips; the handshake will - /// be completed in subsequent read/write calls managed by your event loop. - pub fn connect(ssl: T, stream: S) -> Result, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_fd() as c_int; - let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); - let ret = ssl.ssl.connect(); - if ret > 0 { - Ok(ssl) - } else { - // WantRead/WantWrite is okay here; we'll finish the handshake in - // subsequent send/recv calls. - match ssl.make_error(ret) { - NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), - NonblockingSslError::SslError(other) => Err(other), - } - } - } - - /// Create a new nonblocking server ssl connection on wrapped `stream`. - /// - /// Note that this method will most likely not actually complete the SSL - /// handshake because doing so requires several round trips; the handshake will - /// be completed in subsequent read/write calls managed by your event loop. - pub fn accept(ssl: T, stream: S) -> Result, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_fd() as c_int; - let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); - let ret = ssl.ssl.accept(); - if ret > 0 { - Ok(ssl) - } else { - // WantRead/WantWrite is okay here; we'll finish the handshake in - // subsequent send/recv calls. - match ssl.make_error(ret) { - NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), - NonblockingSslError::SslError(other) => Err(other), - } - } - } -} - -#[cfg(unix)] -impl ::std::os::unix::io::AsRawFd for NonblockingSslStream { - fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd { - self.stream.as_raw_fd() - } -} - -#[cfg(windows)] -impl NonblockingSslStream { - /// Create a new nonblocking client ssl connection on wrapped `stream`. - /// - /// Note that this method will most likely not actually complete the SSL - /// handshake because doing so requires several round trips; the handshake will - /// be completed in subsequent read/write calls managed by your event loop. - pub fn connect(ssl: T, stream: S) -> Result, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_socket() as c_int; - let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); - let ret = ssl.ssl.connect(); - if ret > 0 { - Ok(ssl) - } else { - // WantRead/WantWrite is okay here; we'll finish the handshake in - // subsequent send/recv calls. - match ssl.make_error(ret) { - NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), - NonblockingSslError::SslError(other) => Err(other), - } - } - } - - /// Create a new nonblocking server ssl connection on wrapped `stream`. - /// - /// Note that this method will most likely not actually complete the SSL - /// handshake because doing so requires several round trips; the handshake will - /// be completed in subsequent read/write calls managed by your event loop. - pub fn accept(ssl: T, stream: S) -> Result, SslError> { - let ssl = try!(ssl.into_ssl()); - let fd = stream.as_raw_socket() as c_int; - let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd)); - let ret = ssl.ssl.accept(); - if ret > 0 { - Ok(ssl) - } else { - // WantRead/WantWrite is okay here; we'll finish the handshake in - // subsequent send/recv calls. - match ssl.make_error(ret) { - NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl), - NonblockingSslError::SslError(other) => Err(other), - } - } + self.0.ssl() } } impl NonblockingSslStream { + /// Create a new nonblocking client ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn connect(ssl: T, stream: S) -> Result, SslError> { + SslStream::connect(ssl, stream).map(NonblockingSslStream) + } + + /// Create a new nonblocking server ssl connection on wrapped `stream`. + /// + /// Note that this method will most likely not actually complete the SSL + /// handshake because doing so requires several round trips; the handshake will + /// be completed in subsequent read/write calls managed by your event loop. + pub fn accept(ssl: T, stream: S) -> Result, SslError> { + SslStream::accept(ssl, stream).map(NonblockingSslStream) + } + + fn convert_err(&self, err: Error) -> NonblockingSslError { + match err { + Error::ZeroReturn => SslError::SslSessionClosed.into(), + Error::WantRead(_) => NonblockingSslError::WantRead, + Error::WantWrite(_) => NonblockingSslError::WantWrite, + Error::WantX509Lookup => unreachable!(), + Error::Stream(e) => SslError::StreamError(e).into(), + Error::Ssl(e) => { + SslError::OpenSslErrors(e.iter() + .map(|e| OpensslError::from_error_code(e.error_code())) + .collect()) + .into() + } + } + } + /// Read bytes from the SSL stream into `buf`. /// /// Given the SSL state machine, this method may return either `WantWrite` @@ -1442,11 +1367,10 @@ impl NonblockingSslStream { /// On a return value of `Ok(count)`, count is the number of decrypted /// plaintext bytes copied into the `buf` slice. pub fn read(&mut self, buf: &mut [u8]) -> Result { - let ret = self.ssl.read(buf); - if ret >= 0 { - Ok(ret as usize) - } else { - Err(self.make_error(ret)) + match self.0.ssl_read(buf) { + Ok(n) => Ok(n), + Err(Error::ZeroReturn) => Ok(0), + Err(e) => Err(self.convert_err(e)) } } @@ -1466,11 +1390,6 @@ impl NonblockingSslStream { /// Given a return value of `Ok(count)`, count is the number of plaintext bytes /// from the `buf` slice that were encrypted and written onto the stream. pub fn write(&mut self, buf: &[u8]) -> Result { - let ret = self.ssl.write(buf); - if ret > 0 { - Ok(ret as usize) - } else { - Err(self.make_error(ret)) - } + self.0.ssl_write(buf).map_err(|e| self.convert_err(e)) } }