From 3744e31e1652afc6ab648e777128918da68b478c Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Mon, 25 Dec 2017 21:35:09 -0700 Subject: [PATCH] Fix a bunch of FIXMEs --- openssl/src/dsa.rs | 14 ++++++-------- openssl/src/pkcs12.rs | 29 +++++++++++++---------------- openssl/src/rsa.rs | 26 ++++++++++++-------------- openssl/src/ssl/callbacks.rs | 17 ++++++++--------- openssl/src/ssl/connector.rs | 8 +++----- openssl/src/ssl/error.rs | 13 +++++-------- openssl/src/ssl/mod.rs | 35 ++++++++++++++++++++++------------- openssl/src/ssl/tests/mod.rs | 2 +- 8 files changed, 70 insertions(+), 74 deletions(-) diff --git a/openssl/src/dsa.rs b/openssl/src/dsa.rs index c687531e..e1b9a666 100644 --- a/openssl/src/dsa.rs +++ b/openssl/src/dsa.rs @@ -7,7 +7,7 @@ use ffi; use foreign_types::ForeignTypeRef; -use libc::{c_int, c_char, c_void}; +use libc::{c_char, c_int, c_void}; use std::fmt; use std::ptr; @@ -15,7 +15,7 @@ use {cvt, cvt_p}; use bio::MemBioSlice; use bn::BigNumRef; use error::ErrorStack; -use util::{CallbackState, invoke_passwd_cb_old}; +use util::{invoke_passwd_cb_old, CallbackState}; foreign_type_and_impl_send_sync! { type CType = ffi::DSA; @@ -70,12 +70,10 @@ impl DsaRef { /// OpenSSL documentation at [`DSA_size`] /// /// [`DSA_size`]: https://www.openssl.org/docs/man1.1.0/crypto/DSA_size.html - // FIXME should return u32 - pub fn size(&self) -> Option { - if self.q().is_some() { - unsafe { Some(ffi::DSA_size(self.as_ptr()) as u32) } - } else { - None + pub fn size(&self) -> u32 { + unsafe { + assert!(self.q().is_some()); + ffi::DSA_size(self.as_ptr()) as u32 } } diff --git a/openssl/src/pkcs12.rs b/openssl/src/pkcs12.rs index 84401a04..663e8700 100644 --- a/openssl/src/pkcs12.rs +++ b/openssl/src/pkcs12.rs @@ -9,7 +9,7 @@ use std::ffi::CString; use {cvt, cvt_p}; use pkey::{PKey, PKeyRef}; use error::ErrorStack; -use x509::X509; +use x509::{X509, X509Ref}; use stack::Stack; use nid::Nid; @@ -25,8 +25,7 @@ impl Pkcs12Ref { to_der!(ffi::i2d_PKCS12); /// Extracts the contents of the `Pkcs12`. - // FIXME should take an &[u8] - pub fn parse(&self, pass: &str) -> Result { + pub fn parse(&self, pass: &[u8]) -> Result { unsafe { let pass = CString::new(pass).unwrap(); @@ -46,9 +45,9 @@ impl Pkcs12Ref { let cert = X509::from_ptr(cert); let chain = if chain.is_null() { - Stack::new()? + None } else { - Stack::from_ptr(chain) + Some(Stack::from_ptr(chain)) }; Ok(ParsedPkcs12 { @@ -87,8 +86,7 @@ impl Pkcs12 { pub struct ParsedPkcs12 { pub pkey: PKey, pub cert: X509, - // FIXME Make this Option in the next breaking release - pub chain: Stack, + pub chain: Option>, } pub struct Pkcs12Builder { @@ -147,7 +145,7 @@ impl Pkcs12Builder { password: &str, friendly_name: &str, pkey: &PKeyRef, - cert: &X509, // FIXME X509Ref + cert: &X509Ref, ) -> Result { unsafe { let pass = CString::new(password).unwrap(); @@ -200,7 +198,7 @@ mod test { fn parse() { let der = include_bytes!("../test/identity.p12"); let pkcs12 = Pkcs12::from_der(der).unwrap(); - let parsed = pkcs12.parse("mypass").unwrap(); + let parsed = pkcs12.parse("mypass".as_bytes()).unwrap(); assert_eq!( parsed @@ -211,9 +209,10 @@ mod test { "59172d9313e84459bcff27f967e79e6e9217e584" ); - assert_eq!(parsed.chain.len(), 1); + let chain = parsed.chain.unwrap(); + assert_eq!(chain.len(), 1); assert_eq!( - parsed.chain[0] + chain[0] .fingerprint(MessageDigest::sha1()) .unwrap() .to_hex(), @@ -225,10 +224,8 @@ mod test { fn parse_empty_chain() { let der = include_bytes!("../test/keystore-empty-chain.p12"); let pkcs12 = Pkcs12::from_der(der).unwrap(); - let parsed = pkcs12.parse("cassandra").unwrap(); - - assert_eq!(parsed.chain.len(), 0); - assert_eq!(parsed.chain.into_iter().collect::>().len(), 0); + let parsed = pkcs12.parse("cassandra".as_bytes()).unwrap(); + assert!(parsed.chain.is_none()); } #[test] @@ -266,7 +263,7 @@ mod test { let der = pkcs12.to_der().unwrap(); let pkcs12 = Pkcs12::from_der(&der).unwrap(); - let parsed = pkcs12.parse("mypass").unwrap(); + let parsed = pkcs12.parse("mypass".as_bytes()).unwrap(); assert_eq!( parsed.cert.fingerprint(MessageDigest::sha1()).unwrap(), diff --git a/openssl/src/rsa.rs b/openssl/src/rsa.rs index 83f05247..5ec4d1a7 100644 --- a/openssl/src/rsa.rs +++ b/openssl/src/rsa.rs @@ -51,12 +51,10 @@ impl RsaRef { ffi::i2d_RSAPublicKey ); - // FIXME should return u32 - pub fn size(&self) -> usize { + pub fn size(&self) -> u32 { unsafe { assert!(self.n().is_some()); - - ffi::RSA_size(self.as_ptr()) as usize + ffi::RSA_size(self.as_ptr()) as u32 } } @@ -74,7 +72,7 @@ impl RsaRef { ) -> Result { assert!(self.d().is_some(), "private components missing"); assert!(from.len() <= i32::max_value() as usize); - assert!(to.len() >= self.size()); + assert!(to.len() >= self.size() as usize); unsafe { let len = cvt_n(ffi::RSA_private_decrypt( @@ -102,7 +100,7 @@ impl RsaRef { ) -> Result { assert!(self.d().is_some(), "private components missing"); assert!(from.len() <= i32::max_value() as usize); - assert!(to.len() >= self.size()); + assert!(to.len() >= self.size() as usize); unsafe { let len = cvt_n(ffi::RSA_private_encrypt( @@ -128,7 +126,7 @@ impl RsaRef { padding: Padding, ) -> Result { assert!(from.len() <= i32::max_value() as usize); - assert!(to.len() >= self.size()); + assert!(to.len() >= self.size() as usize); unsafe { let len = cvt_n(ffi::RSA_public_decrypt( @@ -154,7 +152,7 @@ impl RsaRef { padding: Padding, ) -> Result { assert!(from.len() <= i32::max_value() as usize); - assert!(to.len() >= self.size()); + assert!(to.len() >= self.size() as usize); unsafe { let len = cvt_n(ffi::RSA_public_encrypt( @@ -485,7 +483,7 @@ mod test { let key = include_bytes!("../test/rsa.pem.pub"); let public_key = Rsa::public_key_from_pem(key).unwrap(); - let mut result = vec![0; public_key.size()]; + let mut result = vec![0; public_key.size() as usize]; let original_data = b"This is test"; let len = public_key .public_encrypt(original_data, &mut result, Padding::PKCS1) @@ -494,7 +492,7 @@ mod test { let pkey = include_bytes!("../test/rsa.pem"); let private_key = Rsa::private_key_from_pem(pkey).unwrap(); - let mut dec_result = vec![0; private_key.size()]; + let mut dec_result = vec![0; private_key.size() as usize]; let len = private_key .private_decrypt(&result, &mut dec_result, Padding::PKCS1) .unwrap(); @@ -510,10 +508,10 @@ mod test { let msg = vec![0xdeu8, 0xadu8, 0xd0u8, 0x0du8]; - let mut emesg = vec![0; k0.size()]; + let mut emesg = vec![0; k0.size() as usize]; k0.private_encrypt(&msg, &mut emesg, Padding::PKCS1) .unwrap(); - let mut dmesg = vec![0; k1.size()]; + let mut dmesg = vec![0; k1.size() as usize]; let len = k1.public_decrypt(&emesg, &mut dmesg, Padding::PKCS1) .unwrap(); assert_eq!(msg, &dmesg[..len]); @@ -527,9 +525,9 @@ mod test { let msg = vec![0xdeu8, 0xadu8, 0xd0u8, 0x0du8]; - let mut emesg = vec![0; k0.size()]; + let mut emesg = vec![0; k0.size() as usize]; k0.public_encrypt(&msg, &mut emesg, Padding::PKCS1).unwrap(); - let mut dmesg = vec![0; k1.size()]; + let mut dmesg = vec![0; k1.size() as usize]; let len = k1.private_decrypt(&emesg, &mut dmesg, Padding::PKCS1) .unwrap(); assert_eq!(msg, &dmesg[..len]); diff --git a/openssl/src/ssl/callbacks.rs b/openssl/src/ssl/callbacks.rs index d7c48050..78602a54 100644 --- a/openssl/src/ssl/callbacks.rs +++ b/openssl/src/ssl/callbacks.rs @@ -1,5 +1,5 @@ use ffi; -use libc::{c_int, c_uint, c_char, c_uchar, c_void}; +use libc::{c_char, c_int, c_uchar, c_uint, c_void}; use std::any::Any; use std::ffi::CStr; use std::ptr; @@ -11,14 +11,14 @@ use error::ErrorStack; use dh::Dh; #[cfg(any(all(feature = "v101", ossl101), all(feature = "v102", ossl102)))] use ec_key::EcKey; -use ssl::{get_callback_idx, get_ssl_callback_idx, SslRef, SniError, NPN_PROTOS_IDX}; +use ssl::{get_callback_idx, get_ssl_callback_idx, SniError, SslRef, NPN_PROTOS_IDX}; #[cfg(any(all(feature = "v102", ossl102), all(feature = "v110", ossl110)))] use ssl::ALPN_PROTOS_IDX; use x509::X509StoreContextRef; pub extern "C" fn raw_verify(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_CTX) -> c_int where - F: Fn(bool, &X509StoreContextRef) -> bool + Any + 'static + Sync + Send, + F: Fn(bool, &mut X509StoreContextRef) -> bool + Any + 'static + Sync + Send, { unsafe { let idx = ffi::SSL_get_ex_data_X509_STORE_CTX_idx(); @@ -27,7 +27,7 @@ where let verify = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_callback_idx::()); let verify: &F = &*(verify as *mut F); - let ctx = X509StoreContextRef::from_ptr(x509_ctx); + let ctx = X509StoreContextRef::from_ptr_mut(x509_ctx); verify(preverify_ok != 0, ctx) as c_int } @@ -74,7 +74,7 @@ pub extern "C" fn ssl_raw_verify( x509_ctx: *mut ffi::X509_STORE_CTX, ) -> c_int where - F: Fn(bool, &X509StoreContextRef) -> bool + Any + 'static + Sync + Send, + F: Fn(bool, &mut X509StoreContextRef) -> bool + Any + 'static + Sync + Send, { unsafe { let idx = ffi::SSL_get_ex_data_X509_STORE_CTX_idx(); @@ -82,7 +82,7 @@ where let verify = ffi::SSL_get_ex_data(ssl as *const _, get_ssl_callback_idx::()); let verify: &F = &*(verify as *mut F); - let ctx = X509StoreContextRef::from_ptr(x509_ctx); + let ctx = X509StoreContextRef::from_ptr_mut(x509_ctx); verify(preverify_ok != 0, ctx) as c_int } @@ -121,7 +121,6 @@ pub unsafe fn select_proto_using( inlen: c_uint, ex_data: c_int, ) -> c_int { - // First, get the list of protocols (that the client should support) saved in the context // extra data. let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); @@ -132,8 +131,8 @@ pub unsafe fn select_proto_using( let client_len = protocols.len() as c_uint; // Finally, let OpenSSL find a protocol to be used, by matching the given server and // client lists. - if ffi::SSL_select_next_proto(out, outlen, inbuf, inlen, client, client_len) != - ffi::OPENSSL_NPN_NEGOTIATED + if ffi::SSL_select_next_proto(out, outlen, inbuf, inlen, client, client_len) + != ffi::OPENSSL_NPN_NEGOTIATED { ffi::SSL_TLSEXT_ERR_NOACK } else { diff --git a/openssl/src/ssl/connector.rs b/openssl/src/ssl/connector.rs index dc65ad5e..dc13ea97 100644 --- a/openssl/src/ssl/connector.rs +++ b/openssl/src/ssl/connector.rs @@ -63,11 +63,9 @@ impl SslConnectorBuilder { ctx.set_default_verify_paths()?; // From https://github.com/python/cpython/blob/a170fa162dc03f0a014373349e548954fff2e567/Lib/ssl.py#L193 ctx.set_cipher_list( - "TLS13-AES-256-GCM-SHA384:TLS13-CHACHA20-POLY1305-SHA256:\ - TLS13-AES-128-GCM-SHA256:\ - ECDH+AESGCM:ECDH+CHACHA20:DH+AESGCM:DH+CHACHA20:ECDH+AES256:DH+AES256:\ - ECDH+AES128:DH+AES:ECDH+HIGH:DH+HIGH:RSA+AESGCM:RSA+AES:RSA+HIGH:\ - !aNULL:!eNULL:!MD5:!3DES", + "TLS13-AES-256-GCM-SHA384:TLS13-CHACHA20-POLY1305-SHA256:TLS13-AES-128-GCM-SHA256:\ + ECDH+AESGCM:ECDH+CHACHA20:DH+AESGCM:DH+CHACHA20:ECDH+AES256:DH+AES256:ECDH+AES128:\ + DH+AES:ECDH+HIGH:DH+HIGH:RSA+AESGCM:RSA+AES:RSA+HIGH:!aNULL:!eNULL:!MD5:!3DES", )?; setup_verify(&mut ctx); diff --git a/openssl/src/ssl/error.rs b/openssl/src/ssl/error.rs index 48f49979..358a88b0 100644 --- a/openssl/src/ssl/error.rs +++ b/openssl/src/ssl/error.rs @@ -101,11 +101,10 @@ pub enum HandshakeError { SetupFailure(ErrorStack), /// The handshake failed. Failure(MidHandshakeSslStream), - /// The handshake was interrupted midway through. + /// The handshake encountered a `WouldBlock` error midway through. /// /// This error will never be returned for blocking streams. - // FIXME change to WouldBlock - Interrupted(MidHandshakeSslStream), + WouldBlock(MidHandshakeSslStream), } impl StdError for HandshakeError { @@ -113,15 +112,14 @@ impl StdError for HandshakeError { match *self { HandshakeError::SetupFailure(_) => "stream setup failed", HandshakeError::Failure(_) => "the handshake failed", - HandshakeError::Interrupted(_) => "the handshake was interrupted", + HandshakeError::WouldBlock(_) => "the handshake was interrupted", } } fn cause(&self) -> Option<&StdError> { match *self { HandshakeError::SetupFailure(ref e) => Some(e), - HandshakeError::Failure(ref s) | - HandshakeError::Interrupted(ref s) => Some(s.error()), + HandshakeError::Failure(ref s) | HandshakeError::WouldBlock(ref s) => Some(s.error()), } } } @@ -131,8 +129,7 @@ impl fmt::Display for HandshakeError { f.write_str(StdError::description(self))?; match *self { HandshakeError::SetupFailure(ref e) => write!(f, ": {}", e)?, - HandshakeError::Failure(ref s) | - HandshakeError::Interrupted(ref s) => { + HandshakeError::Failure(ref s) | HandshakeError::WouldBlock(ref s) => { write!(f, ": {}", s.error())?; if let Some(err) = s.ssl().verify_result() { write!(f, ": {}", err)?; diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 371b66c0..c59ad8d8 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -41,12 +41,12 @@ //! let mut pkcs12 = vec![]; //! file.read_to_end(&mut pkcs12).unwrap(); //! let pkcs12 = Pkcs12::from_der(&pkcs12).unwrap(); -//! let identity = pkcs12.parse("password123").unwrap(); +//! let identity = pkcs12.parse(b"password123").unwrap(); //! //! let acceptor = SslAcceptorBuilder::mozilla_intermediate(SslMethod::tls(), //! &identity.pkey, //! &identity.cert, -//! &identity.chain) +//! &identity.chain.unwrap()) //! .unwrap() //! .build(); //! let acceptor = Arc::new(acceptor); @@ -83,7 +83,7 @@ use std::fmt; use std::io; use std::io::prelude::*; use std::marker::PhantomData; -use std::mem; +use std::mem::{self, ManuallyDrop}; use std::ops::{Deref, DerefMut}; use std::panic::resume_unwind; use std::path::Path; @@ -490,7 +490,7 @@ impl SslContextBuilder { pub fn set_verify_callback(&mut self, mode: SslVerifyMode, verify: F) where // FIXME should take a mutable reference to the store - F: Fn(bool, &X509StoreContextRef) -> bool + Any + 'static + Sync + Send, + F: Fn(bool, &mut X509StoreContextRef) -> bool + Any + 'static + Sync + Send, { unsafe { let verify = Box::new(verify); @@ -1500,7 +1500,7 @@ impl SslRef { pub fn set_verify_callback(&mut self, mode: SslVerifyMode, verify: F) where // FIXME should take a mutable reference to the x509 store - F: Fn(bool, &X509StoreContextRef) -> bool + Any + 'static + Sync + Send, + F: Fn(bool, &mut X509StoreContextRef) -> bool + Any + 'static + Sync + Send, { unsafe { let verify = Box::new(verify); @@ -2069,7 +2069,7 @@ impl Ssl { } else { match stream.make_error(ret) { e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { - Err(HandshakeError::Interrupted(MidHandshakeSslStream { + Err(HandshakeError::WouldBlock(MidHandshakeSslStream { stream: stream, error: e, })) @@ -2103,7 +2103,7 @@ impl Ssl { } else { match stream.make_error(ret) { e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { - Err(HandshakeError::Interrupted(MidHandshakeSslStream { + Err(HandshakeError::WouldBlock(MidHandshakeSslStream { stream: stream, error: e, })) @@ -2163,7 +2163,7 @@ impl MidHandshakeSslStream { match self.stream.make_error(ret) { e @ Error::WantWrite(_) | e @ Error::WantRead(_) => { self.error = e; - Err(HandshakeError::Interrupted(self)) + Err(HandshakeError::WouldBlock(self)) } err => { self.error = err; @@ -2176,12 +2176,21 @@ impl MidHandshakeSslStream { /// A TLS session over a stream. pub struct SslStream { - // FIXME use ManuallyDrop - ssl: Ssl, - _method: BioMethod, // NOTE: this *must* be after the Ssl field so things drop right + ssl: ManuallyDrop, + method: ManuallyDrop, _p: PhantomData, } +impl Drop for SslStream { + fn drop(&mut self) { + // ssl holds a reference to method internally so it has to drop first + unsafe { + ManuallyDrop::drop(&mut self.ssl); + ManuallyDrop::drop(&mut self.method); + } + } +} + impl fmt::Debug for SslStream where S: fmt::Debug, @@ -2201,8 +2210,8 @@ impl SslStream { ffi::SSL_set_bio(ssl.as_ptr(), bio, bio); SslStream { - ssl: ssl, - _method: method, + ssl: ManuallyDrop::new(ssl), + method: ManuallyDrop::new(method), _p: PhantomData, } } diff --git a/openssl/src/ssl/tests/mod.rs b/openssl/src/ssl/tests/mod.rs index 365f0168..69369ed4 100644 --- a/openssl/src/ssl/tests/mod.rs +++ b/openssl/src/ssl/tests/mod.rs @@ -753,7 +753,7 @@ fn wait_io(stream: &TcpStream, read: bool, timeout_ms: u32) -> bool { fn handshake(res: Result, HandshakeError>) -> SslStream { match res { Ok(s) => s, - Err(HandshakeError::Interrupted(s)) => { + Err(HandshakeError::WouldBlock(s)) => { wait_io(s.get_ref(), true, 1_000); handshake(s.handshake()) }