From 1df131ff810dabd9d0452688bf8956e0af58b06a Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 12 Dec 2015 15:01:16 -0800 Subject: [PATCH] Build out a new error type --- openssl/src/ssl/error.rs | 148 +++++++++++++++++++++++++++++++++++---- openssl/src/ssl/mod.rs | 118 +++++++++++++++++++++++++------ 2 files changed, 231 insertions(+), 35 deletions(-) diff --git a/openssl/src/ssl/error.rs b/openssl/src/ssl/error.rs index d76494f1..9a1a63b2 100644 --- a/openssl/src/ssl/error.rs +++ b/openssl/src/ssl/error.rs @@ -3,12 +3,136 @@ pub use self::OpensslError::*; use libc::c_ulong; use std::error; +use std::error::Error as StdError; use std::fmt; use std::ffi::CStr; use std::io; +use std::str; use ffi; +/// An SSL error. +#[derive(Debug)] +pub enum Error { + /// The SSL session has been closed by the other end + ZeroReturn, + /// An attempt to read data from the underlying socket returned + /// `WouldBlock`. Wait for read readiness and reattempt the operation. + WantRead(io::Error), + /// An attempt to write data from the underlying socket returned + /// `WouldBlock`. Wait for write readiness and reattempt the operation. + WantWrite(io::Error), + #[doc(hidden)] // unused for now + WantX509Lookup, + /// An error reported by the underlying stream. + Stream(io::Error), + /// An error in the OpenSSL library. + Ssl(Vec), +} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + try!(fmt.write_str(self.description())); + match *self { + Error::Stream(ref err) => write!(fmt, ": {}", err), + Error::WantRead(ref err) => write!(fmt, ": {}", err), + Error::WantWrite(ref err) => write!(fmt, ": {}", err), + Error::Ssl(ref errs) => { + let mut first = true; + for err in errs { + if first { + try!(fmt.write_str(": ")); + first = false; + } else { + try!(fmt.write_str(", ")); + } + try!(fmt.write_str(&err.reason())) + } + Ok(()) + } + _ => Ok(()) + } + } +} + +impl error::Error for Error { + fn description(&self) -> &str { + match *self { + Error::ZeroReturn => "The SSL session was closed by the other end", + Error::WantRead(_) => "A read attempt returned a `WouldBlock` error", + Error::WantWrite(_) => "A write attempt returned a `WouldBlock` error", + Error::WantX509Lookup => "The client certificate callback requested to be called again", + Error::Stream(_) => "The underlying stream reported an error", + Error::Ssl(_) => "The OpenSSL library reported an error", + } + } + + fn cause(&self) -> Option<&error::Error> { + match *self { + Error::WantRead(ref err) => Some(err), + Error::WantWrite(ref err) => Some(err), + Error::Stream(ref err) => Some(err), + _ => None + } + } +} + +/// An error reported from OpenSSL. +pub struct OpenSslError(c_ulong); + +impl OpenSslError { + /// Returns the contents of the OpenSSL error stack. + pub fn get_stack() -> Vec { + ffi::init(); + + let mut errs = vec!(); + loop { + match unsafe { ffi::ERR_get_error() } { + 0 => break, + err => errs.push(OpenSslError(err)) + } + } + errs + } + + /// Returns the name of the library reporting the error. + pub fn library(&self) -> &'static str { + get_lib(self.0) + } + + /// Returns the name of the function reporting the error. + pub fn function(&self) -> &'static str { + get_func(self.0) + } + + /// Returns the reason for the error. + pub fn reason(&self) -> &'static str { + get_reason(self.0) + } +} + +impl fmt::Debug for OpenSslError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("OpenSslError") + .field("library", &self.library()) + .field("function", &self.function()) + .field("reason", &self.reason()) + .finish() + } +} + +impl fmt::Display for OpenSslError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str(&self.reason()) + } +} + +impl error::Error for OpenSslError { + fn description(&self) -> &str { + "An OpenSSL error" + } +} + /// An SSL error #[derive(Debug)] pub enum SslError { @@ -115,27 +239,27 @@ pub enum OpensslError { } } -fn get_lib(err: c_ulong) -> String { +fn get_lib(err: c_ulong) -> &'static str { unsafe { let cstr = ffi::ERR_lib_error_string(err); - let bytes = CStr::from_ptr(cstr as *const _).to_bytes().to_vec(); - String::from_utf8(bytes).unwrap() + let bytes = CStr::from_ptr(cstr as *const _).to_bytes(); + str::from_utf8(bytes).unwrap() } } -fn get_func(err: c_ulong) -> String { +fn get_func(err: c_ulong) -> &'static str { unsafe { let cstr = ffi::ERR_func_error_string(err); - let bytes = CStr::from_ptr(cstr as *const _).to_bytes().to_vec(); - String::from_utf8(bytes).unwrap() + let bytes = CStr::from_ptr(cstr as *const _).to_bytes(); + str::from_utf8(bytes).unwrap() } } -fn get_reason(err: c_ulong) -> String { +fn get_reason(err: c_ulong) -> &'static str { unsafe { let cstr = ffi::ERR_reason_error_string(err); - let bytes = CStr::from_ptr(cstr as *const _).to_bytes().to_vec(); - String::from_utf8(bytes).unwrap() + let bytes = CStr::from_ptr(cstr as *const _).to_bytes(); + str::from_utf8(bytes).unwrap() } } @@ -161,9 +285,9 @@ impl SslError { fn from_error_code(err: c_ulong) -> OpensslError { ffi::init(); UnknownError { - library: get_lib(err), - function: get_func(err), - reason: get_reason(err) + library: get_lib(err).to_owned(), + function: get_func(err).to_owned(), + reason: get_reason(err).to_owned() } } } diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index bc64d7c8..89c8bbfc 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -22,7 +22,7 @@ use std::marker::PhantomData; use ffi; use ffi_extras; use dh::DH; -use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors}; +use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError}; use x509::{X509StoreContext, X509FileType, X509}; use crypto::pkey::PKey; @@ -31,6 +31,9 @@ mod bio; #[cfg(test)] mod tests; +#[doc(inline)] +pub use ssl::error::Error; + extern "C" { fn rust_SSL_clone(ssl: *mut ffi::SSL); } @@ -954,7 +957,17 @@ impl SslStream { if ret > 0 { Ok(stream) } else { - Err(stream.make_error(ret)) + match stream.make_old_error(ret) { + SslError::StreamError(e) => { + // This is fine - nonblocking sockets will finish the handshake in read/write + if e.kind() == io::ErrorKind::WouldBlock { + Ok(stream) + } else { + Err(SslError::StreamError(e)) + } + } + e => Err(e) + } } } @@ -966,7 +979,17 @@ impl SslStream { if ret > 0 { Ok(stream) } else { - Err(stream.make_error(ret)) + match stream.make_old_error(ret) { + SslError::StreamError(e) => { + // This is fine - nonblocking sockets will finish the handshake in read/write + if e.kind() == io::ErrorKind::WouldBlock { + Ok(stream) + } else { + Err(SslError::StreamError(e)) + } + } + e => Err(e) + } } } @@ -986,7 +1009,31 @@ impl SslStream { } impl SslStream { - fn make_error(&mut self, ret: c_int) -> SslError { + fn make_error(&mut self, ret: c_int) -> Error { + match self.ssl.get_error(ret) { + LibSslError::ErrorSsl => Error::Ssl(OpenSslError::get_stack()), + LibSslError::ErrorSyscall => { + let errs = OpenSslError::get_stack(); + if errs.is_empty() { + if ret == 0 { + Error::Stream(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed")) + } else { + Error::Stream(self.get_bio_error()) + } + } else { + Error::Ssl(errs) + } + } + LibSslError::ErrorZeroReturn => Error::ZeroReturn, + LibSslError::ErrorWantWrite => Error::WantWrite(self.get_bio_error()), + LibSslError::ErrorWantRead => Error::WantRead(self.get_bio_error()), + err => Error::Stream(io::Error::new(io::ErrorKind::Other, + format!("unexpected error {:?}", err))), + } + } + + fn make_old_error(&mut self, ret: c_int) -> SslError { match self.ssl.get_error(ret) { LibSslError::ErrorSsl => SslError::get(), LibSslError::ErrorSyscall => { @@ -1045,6 +1092,32 @@ impl SslStream { } } + /// Like `read`, but returns an `ssl::Error` rather than an `io::Error`. + /// + /// This is particularly useful with a nonblocking socket, where the error + /// value will identify if OpenSSL is waiting on read or write readiness. + pub fn ssl_read(&mut self, buf: &mut [u8]) -> Result { + let ret = self.ssl.read(buf); + if ret >= 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } + + /// Like `write`, but returns an `ssl::Error` rather than an `io::Error`. + /// + /// This is particularly useful with a nonblocking socket, where the error + /// value will identify if OpenSSL is waiting on read or write readiness. + pub fn ssl_write(&mut self, buf: &[u8]) -> Result { + let ret = self.ssl.write(buf); + if ret >= 0 { + Ok(ret as usize) + } else { + Err(self.make_error(ret)) + } + } + /// Returns the OpenSSL `Ssl` object associated with this stream. pub fn ssl(&self) -> &Ssl { &self.ssl @@ -1061,30 +1134,27 @@ impl SslStream<::std::net::TcpStream> { impl Read for SslStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let ret = self.ssl.read(buf); - if ret >= 0 { - return Ok(ret as usize); - } - - match self.make_error(ret) { - SslError::SslSessionClosed => Ok(0), - SslError::StreamError(e) => Err(e), - e => Err(io::Error::new(io::ErrorKind::Other, e)), + match self.ssl_read(buf) { + Ok(n) => Ok(n), + Err(Error::ZeroReturn) => Ok(0), + Err(Error::Stream(e)) => Err(e), + Err(Error::WantRead(e)) => Err(e), + Err(Error::WantWrite(e)) => Err(e), + Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)), } } } impl Write for SslStream { fn write(&mut self, buf: &[u8]) -> io::Result { - let ret = self.ssl.write(buf); - if ret > 0 { - return Ok(ret as usize); - } - - match self.make_error(ret) { - SslError::StreamError(e) => Err(e), - e => Err(io::Error::new(io::ErrorKind::Other, e)), - } + self.ssl_write(buf).map_err(|e| { + match e { + Error::Stream(e) => e, + Error::WantRead(e) => e, + Error::WantWrite(e) => e, + e => io::Error::new(io::ErrorKind::Other, e), + } + }) } fn flush(&mut self) -> io::Result<()> { @@ -1174,7 +1244,9 @@ impl MaybeSslStream { } } -/// An SSL stream wrapping a nonblocking socket. +/// # Deprecated +/// +/// Use `SslStream` with `ssl_read` and `ssl_write`. #[derive(Clone)] pub struct NonblockingSslStream { stream: S,