From a1a3219483977bef2c059b1681932d95eb20180e Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 4 Nov 2017 13:32:18 -0700 Subject: [PATCH] Handle local retries OpenSSL can return SSL_ERROR_WANT_READ even on blocking sockets after renegotiation or heartbeats. Heartbeats ignore the flag that normally makes these things handled internally anyway on 1.0.2. To handle this more properly, we now have a special error type we use to signal this event. The `Read` and `Write` implementation automatically retry in this situation since that's what you normally want. People can use `ssl_read` and `ssl_write` if they want the lower level control. Closes #760 --- openssl/src/ssl/error.rs | 27 +++++++ openssl/src/ssl/mod.rs | 167 +++++++++++++++++---------------------- 2 files changed, 100 insertions(+), 94 deletions(-) diff --git a/openssl/src/ssl/error.rs b/openssl/src/ssl/error.rs index db78e2c8..2244fd7f 100644 --- a/openssl/src/ssl/error.rs +++ b/openssl/src/ssl/error.rs @@ -66,6 +66,33 @@ impl From for Error { } } +/// An error indicating that the operation can be immediately retried. +/// +/// OpenSSL's [`SSL_read`] and [`SSL_write`] functions can return `SSL_ERROR_WANT_READ` even when +/// the underlying socket is performing blocking IO in certain cases. When this happens, the +/// the operation can be immediately retried. +/// +/// To signal this event, the `io::Error` inside of [`Error::WantRead`] will be constructed around +/// a `RetryError`. +/// +/// [`SSL_read`]: https://www.openssl.org/docs/manmaster/man3/SSL_read.html +/// [`SSL_write`]: https://www.openssl.org/docs/manmaster/man3/SSL_write.html +/// [`Error::WantRead`]: enum.Error.html#variant.WantRead +#[derive(Debug)] +pub struct RetryError; + +impl fmt::Display for RetryError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str(error::Error::description(self)) + } +} + +impl error::Error for RetryError { + fn description(&self) -> &str { + "operation must be retried" + } +} + /// An error or intermediate state after a TLS handshake attempt. #[derive(Debug)] pub enum HandshakeError { diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 5a924a64..62b617a9 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -72,7 +72,7 @@ //! ``` use ffi; use foreign_types::{ForeignType, ForeignTypeRef, Opaque}; -use libc::{c_int, c_void, c_long, c_ulong}; +use libc::{c_int, c_long, c_ulong, c_void}; use libc::{c_uchar, c_uint}; use std::any::Any; use std::any::TypeId; @@ -93,12 +93,12 @@ use std::slice; use std::str; use std::sync::Mutex; -use {init, cvt, cvt_p, cvt_n}; +use {cvt, cvt_n, cvt_p, init}; use dh::{Dh, DhRef}; use ec::EcKeyRef; #[cfg(any(all(feature = "v101", ossl101), all(feature = "v102", ossl102)))] use ec::EcKey; -use x509::{X509StoreContextRef, X509FileType, X509, X509Ref, X509VerifyError, X509Name}; +use x509::{X509, X509FileType, X509Name, X509Ref, X509StoreContextRef, X509VerifyError}; use x509::store::{X509StoreBuilderRef, X509StoreRef}; #[cfg(any(all(feature = "v102", ossl102), all(feature = "v110", ossl110)))] use x509::store::X509Store; @@ -111,9 +111,9 @@ use stack::{Stack, StackRef}; use ssl::bio::BioMethod; use ssl::callbacks::*; -pub use ssl::connector::{SslConnectorBuilder, SslConnector, SslAcceptorBuilder, SslAcceptor, - ConnectConfiguration}; -pub use ssl::error::{Error, HandshakeError}; +pub use ssl::connector::{ConnectConfiguration, SslAcceptor, SslAcceptorBuilder, SslConnector, + SslConnectorBuilder}; +pub use ssl::error::{Error, HandshakeError, RetryError}; mod error; mod callbacks; @@ -416,10 +416,8 @@ impl SslContextBuilder { pub fn set_verify_cert_store(&mut self, cert_store: X509Store) -> Result<(), ErrorStack> { unsafe { let ptr = cert_store.as_ptr(); - cvt( - ffi::SSL_CTX_set0_verify_cert_store(self.as_ptr(), ptr) as - c_int, - )?; + cvt(ffi::SSL_CTX_set0_verify_cert_store(self.as_ptr(), ptr) + as c_int)?; mem::forget(cert_store); Ok(()) @@ -461,8 +459,9 @@ impl SslContextBuilder { pub fn set_tmp_ecdh(&mut self, key: &EcKeyRef) -> Result<(), ErrorStack> { unsafe { - cvt(ffi::SSL_CTX_set_tmp_ecdh(self.as_ptr(), key.as_ptr()) as - c_int).map(|_| ()) + cvt(ffi::SSL_CTX_set_tmp_ecdh(self.as_ptr(), key.as_ptr()) + as c_int) + .map(|_| ()) } } @@ -579,10 +578,7 @@ impl SslContextBuilder { /// `set_certificate` to a trusted root. pub fn add_extra_chain_cert(&mut self, cert: X509) -> Result<(), ErrorStack> { unsafe { - cvt(ffi::SSL_CTX_add_extra_chain_cert( - self.as_ptr(), - cert.as_ptr(), - ) as c_int)?; + cvt(ffi::SSL_CTX_add_extra_chain_cert(self.as_ptr(), cert.as_ptr()) as c_int)?; mem::forget(cert); Ok(()) } @@ -766,10 +762,9 @@ impl SslContextBuilder { Box::into_raw(callback) as *mut c_void, ); let f: unsafe extern "C" fn(_, _) -> _ = raw_tlsext_status::; - cvt(ffi::SSL_CTX_set_tlsext_status_cb( - self.as_ptr(), - Some(f), - ) as c_int).map(|_| ()) + cvt(ffi::SSL_CTX_set_tlsext_status_cb(self.as_ptr(), Some(f)) + as c_int) + .map(|_| ()) } } @@ -781,7 +776,8 @@ impl SslContextBuilder { #[cfg(not(osslconf = "OPENSSL_NO_PSK"))] pub fn set_psk_callback(&mut self, callback: F) where - F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8], &mut [u8]) -> Result + F: Fn(&mut SslRef, Option<&[u8]>, &mut [u8], &mut [u8]) + -> Result + Any + 'static + Sync @@ -1239,10 +1235,8 @@ impl SslRef { pub fn set_hostname(&mut self, hostname: &str) -> Result<(), ErrorStack> { let cstr = CString::new(hostname).unwrap(); unsafe { - cvt(ffi::SSL_set_tlsext_host_name( - self.as_ptr(), - cstr.as_ptr() as *mut _, - ) as c_int).map(|_| ()) + cvt(ffi::SSL_set_tlsext_host_name(self.as_ptr(), cstr.as_ptr() as *mut _) as c_int) + .map(|_| ()) } } @@ -1373,9 +1367,7 @@ impl SslRef { return None; } let meth = ffi::SSL_COMP_get_name(ptr); - Some( - str::from_utf8(CStr::from_ptr(meth as *const _).to_bytes()).unwrap(), - ) + Some(str::from_utf8(CStr::from_ptr(meth as *const _).to_bytes()).unwrap()) } } @@ -1392,9 +1384,7 @@ impl SslRef { return None; } - Some( - str::from_utf8(CStr::from_ptr(name as *const _).to_bytes()).unwrap(), - ) + Some(str::from_utf8(CStr::from_ptr(name as *const _).to_bytes()).unwrap()) } } @@ -1459,10 +1449,7 @@ impl SslRef { /// Sets the status response a client wishes the server to reply with. pub fn set_status_type(&mut self, type_: StatusType) -> Result<(), ErrorStack> { unsafe { - cvt(ffi::SSL_set_tlsext_status_type( - self.as_ptr(), - type_.as_raw(), - ) as c_int).map(|_| ()) + cvt(ffi::SSL_set_tlsext_status_type(self.as_ptr(), type_.as_raw()) as c_int).map(|_| ()) } } @@ -1494,7 +1481,8 @@ impl SslRef { self.as_ptr(), p as *mut c_uchar, response.len() as c_long, - ) as c_int).map(|_| ()) + ) as c_int) + .map(|_| ()) } } @@ -1561,19 +1549,16 @@ impl Ssl { Ok(stream) } else { match stream.make_error(ret) { - e @ Error::WantWrite(_) | - e @ Error::WantRead(_) => { + e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { Err(HandshakeError::Interrupted(MidHandshakeSslStream { stream: stream, error: e, })) } - err => { - Err(HandshakeError::Failure(MidHandshakeSslStream { - stream: stream, - error: err, - })) - } + err => Err(HandshakeError::Failure(MidHandshakeSslStream { + stream: stream, + error: err, + })), } } } @@ -1594,19 +1579,16 @@ impl Ssl { Ok(stream) } else { match stream.make_error(ret) { - e @ Error::WantWrite(_) | - e @ Error::WantRead(_) => { + e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { Err(HandshakeError::Interrupted(MidHandshakeSslStream { stream: stream, error: e, })) } - err => { - Err(HandshakeError::Failure(MidHandshakeSslStream { - stream: stream, - error: err, - })) - } + err => Err(HandshakeError::Failure(MidHandshakeSslStream { + stream: stream, + error: err, + })), } } } @@ -1652,8 +1634,7 @@ impl MidHandshakeSslStream { Ok(self.stream) } else { match self.stream.make_error(ret) { - e @ Error::WantWrite(_) | - e @ Error::WantRead(_) => { + e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { self.error = e; Err(HandshakeError::Interrupted(self)) } @@ -1775,12 +1756,10 @@ impl SslStream { if errs.errors().is_empty() { match self.get_bio_error() { Some(err) => Error::Stream(err), - None => { - Error::Stream(io::Error::new( - io::ErrorKind::ConnectionAborted, - "unexpected EOF observed", - )) - } + None => Error::Stream(io::Error::new( + io::ErrorKind::ConnectionAborted, + "unexpected EOF observed", + )), } } else { Error::Ssl(errs) @@ -1790,33 +1769,24 @@ impl SslStream { ffi::SSL_ERROR_WANT_WRITE => { let err = match self.get_bio_error() { Some(err) => err, - None => { - io::Error::new( - io::ErrorKind::Other, - "BUG: got an SSL_ERROR_WANT_WRITE with no error in the BIO", - ) - } + None => io::Error::new( + io::ErrorKind::Other, + "BUG: got an SSL_ERROR_WANT_WRITE with no error in the BIO", + ), }; Error::WantWrite(err) } ffi::SSL_ERROR_WANT_READ => { let err = match self.get_bio_error() { Some(err) => err, - None => { - io::Error::new( - io::ErrorKind::Other, - "BUG: got an SSL_ERROR_WANT_READ with no error in the BIO", - ) - } + None => io::Error::new(io::ErrorKind::Other, RetryError), }; Error::WantRead(err) } - err => { - Error::Stream(io::Error::new( - io::ErrorKind::InvalidData, - format!("unexpected error {}", err), - )) - } + err => Error::Stream(io::Error::new( + io::ErrorKind::InvalidData, + format!("unexpected error {}", err), + )), } } @@ -1859,25 +1829,34 @@ impl SslStream { impl Read for SslStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.ssl_read(buf) { - Ok(n) => Ok(n), - Err(Error::ZeroReturn) => Ok(0), - Err(Error::Stream(e)) => Err(e), - Err(Error::WantRead(e)) => Err(e), - Err(Error::WantWrite(e)) => Err(e), - Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)), + loop { + match self.ssl_read(buf) { + Ok(n) => return Ok(n), + Err(Error::ZeroReturn) => return Ok(0), + Err(Error::WantRead(ref e)) + if e.get_ref().map_or(false, |e| e.is::()) => {} + Err(Error::Stream(e)) | Err(Error::WantRead(e)) | Err(Error::WantWrite(e)) => { + return Err(e); + } + Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), + } } } } impl Write for SslStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.ssl_write(buf).map_err(|e| match e { - Error::Stream(e) => e, - Error::WantRead(e) => e, - Error::WantWrite(e) => e, - e => io::Error::new(io::ErrorKind::Other, e), - }) + loop { + match self.ssl_write(buf) { + Ok(n) => return Ok(n), + Err(Error::WantRead(ref e)) + if e.get_ref().map_or(false, |e| e.is::()) => {} + Err(Error::Stream(e)) | Err(Error::WantRead(e)) | Err(Error::WantWrite(e)) => { + return Err(e); + } + Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), + } + } } fn flush(&mut self) -> io::Result<()> { @@ -1902,8 +1881,8 @@ mod compat { use ffi; use libc::c_int; - pub use ffi::{SSL_CTX_get_options, SSL_CTX_set_options, SSL_CTX_clear_options, SSL_CTX_up_ref, - SSL_SESSION_get_master_key, SSL_is_server, SSL_SESSION_up_ref}; + pub use ffi::{SSL_CTX_clear_options, SSL_CTX_get_options, SSL_CTX_set_options, SSL_CTX_up_ref, + SSL_SESSION_get_master_key, SSL_SESSION_up_ref, SSL_is_server}; pub unsafe fn get_new_idx(f: ffi::CRYPTO_EX_free) -> c_int { ffi::CRYPTO_get_ex_new_index( @@ -1942,7 +1921,7 @@ mod compat { use std::ptr; use ffi; - use libc::{self, c_long, c_ulong, c_int, size_t, c_uchar}; + use libc::{self, c_int, c_long, c_uchar, c_ulong, size_t}; pub unsafe fn SSL_CTX_get_options(ctx: *const ffi::SSL_CTX) -> c_ulong { ffi::SSL_CTX_ctrl(ctx as *mut _, ffi::SSL_CTRL_OPTIONS, 0, ptr::null_mut()) as c_ulong