diff --git a/openssl-sys/src/lib.rs b/openssl-sys/src/lib.rs index 519b0001..28d4a59c 100644 --- a/openssl-sys/src/lib.rs +++ b/openssl-sys/src/lib.rs @@ -566,6 +566,9 @@ extern "C" { pub fn PEM_read_bio_PUBKEY(bio: *mut BIO, out: *mut *mut EVP_PKEY, callback: Option, user_data: *mut c_void) -> *mut X509; + pub fn PEM_read_bio_RSAPrivateKey(bio: *mut BIO, rsa: *mut *mut RSA, callback: Option, user_data: *mut c_void) -> *mut RSA; + pub fn PEM_read_bio_RSA_PUBKEY(bio: *mut BIO, rsa: *mut *mut RSA, callback: Option, user_data: *mut c_void) -> *mut RSA; + pub fn PEM_write_bio_PrivateKey(bio: *mut BIO, pkey: *mut EVP_PKEY, cipher: *const EVP_CIPHER, kstr: *mut c_char, klen: c_int, callback: Option, diff --git a/openssl/Cargo.toml b/openssl/Cargo.toml index 9c56623e..2dd9c32f 100644 --- a/openssl/Cargo.toml +++ b/openssl/Cargo.toml @@ -25,6 +25,8 @@ rfc5114 = ["openssl-sys/rfc5114"] ecdh_auto = ["openssl-sys-extras/ecdh_auto"] pkcs5_pbkdf2_hmac = ["openssl-sys/pkcs5_pbkdf2_hmac"] +nightly = [] + [dependencies] bitflags = ">= 0.2, < 0.4" lazy_static = "0.1" diff --git a/openssl/src/c_helpers.c b/openssl/src/c_helpers.c index 402c36ec..1b48565e 100644 --- a/openssl/src/c_helpers.c +++ b/openssl/src/c_helpers.c @@ -7,3 +7,11 @@ void rust_SSL_clone(SSL *ssl) { void rust_SSL_CTX_clone(SSL_CTX *ctx) { CRYPTO_add(&ctx->references,1,CRYPTO_LOCK_SSL_CTX); } + +void rust_EVP_PKEY_clone(EVP_PKEY *pkey) { + CRYPTO_add(&pkey->references,1,CRYPTO_LOCK_EVP_PKEY); +} + +void rust_X509_clone(X509 *x509) { + CRYPTO_add(&x509->references,1,CRYPTO_LOCK_X509); +} diff --git a/openssl/src/crypto/pkey.rs b/openssl/src/crypto/pkey.rs index 25ce28e8..dee6cb8b 100644 --- a/openssl/src/crypto/pkey.rs +++ b/openssl/src/crypto/pkey.rs @@ -52,11 +52,18 @@ fn openssl_hash_nid(hash: HashType) -> c_int { } } +extern "C" { + fn rust_EVP_PKEY_clone(pkey: *mut ffi::EVP_PKEY); +} + pub struct PKey { evp: *mut ffi::EVP_PKEY, parts: Parts, } +unsafe impl Send for PKey {} +unsafe impl Sync for PKey {} + /// Represents a public key, optionally with a private key attached. impl PKey { pub fn new() -> PKey { @@ -118,6 +125,54 @@ impl PKey { } } + /// Reads an RSA private key from PEM, takes ownership of handle + pub fn private_rsa_key_from_pem(reader: &mut R) -> Result + where R: Read + { + let mut mem_bio = try!(MemBio::new()); + try!(io::copy(reader, &mut mem_bio).map_err(StreamError)); + + unsafe { + let rsa = try_ssl_null!(ffi::PEM_read_bio_RSAPrivateKey(mem_bio.get_handle(), + ptr::null_mut(), + None, + ptr::null_mut())); + let evp = ffi::EVP_PKEY_new(); + if ffi::EVP_PKEY_set1_RSA(evp, rsa) == 0 { + return Err(SslError::get()); + } + + Ok(PKey { + evp: evp, + parts: Parts::Public, + }) + } + } + + /// Reads an RSA public key from PEM, takes ownership of handle + pub fn public_rsa_key_from_pem(reader: &mut R) -> Result + where R: Read + { + let mut mem_bio = try!(MemBio::new()); + try!(io::copy(reader, &mut mem_bio).map_err(StreamError)); + + unsafe { + let rsa = try_ssl_null!(ffi::PEM_read_bio_RSA_PUBKEY(mem_bio.get_handle(), + ptr::null_mut(), + None, + ptr::null_mut())); + let evp = ffi::EVP_PKEY_new(); + if ffi::EVP_PKEY_set1_RSA(evp, rsa) == 0 { + return Err(SslError::get()); + } + + Ok(PKey { + evp: evp, + parts: Parts::Public, + }) + } + } + fn _tostr(&self, f: unsafe extern "C" fn(*mut ffi::RSA, *const *mut u8) -> c_int) -> Vec { unsafe { let rsa = ffi::EVP_PKEY_get1_RSA(self.evp); @@ -549,6 +604,16 @@ impl Drop for PKey { } } +impl Clone for PKey { + fn clone(&self) -> Self { + unsafe { + rust_EVP_PKEY_clone(self.evp); + } + + PKey::from_handle(self.evp, self.parts) + } +} + #[cfg(test)] mod tests { use std::path::Path; @@ -613,6 +678,26 @@ mod tests { super::PKey::public_key_from_pem(&mut file).unwrap(); } + #[test] + fn test_private_rsa_key_from_pem() { + let key_path = Path::new("test/key.pem"); + let mut file = File::open(&key_path) + .ok() + .expect("Failed to open `test/key.pem`"); + + super::PKey::private_rsa_key_from_pem(&mut file).unwrap(); + } + + #[test] + fn test_public_rsa_key_from_pem() { + let key_path = Path::new("test/key.pem.pub"); + let mut file = File::open(&key_path) + .ok() + .expect("Failed to open `test/key.pem.pub`"); + + super::PKey::public_rsa_key_from_pem(&mut file).unwrap(); + } + #[test] fn test_private_encrypt() { let mut k0 = super::PKey::new(); diff --git a/openssl/src/lib.rs b/openssl/src/lib.rs index 88b67d97..54feab0d 100644 --- a/openssl/src/lib.rs +++ b/openssl/src/lib.rs @@ -1,4 +1,5 @@ #![doc(html_root_url="https://sfackler.github.io/rust-openssl/doc/v0.7.4")] +#![cfg_attr(feature = "nightly", feature(const_fn, recover, panic_propagate))] #[macro_use] extern crate bitflags; diff --git a/openssl/src/ssl/bio.rs b/openssl/src/ssl/bio.rs index a361ae81..aa445562 100644 --- a/openssl/src/ssl/bio.rs +++ b/openssl/src/ssl/bio.rs @@ -1,11 +1,12 @@ use libc::{c_char, c_int, c_long, c_void, strlen}; use ffi::{BIO, BIO_METHOD, BIO_CTRL_FLUSH, BIO_TYPE_NONE, BIO_new}; use ffi_extras::{BIO_clear_retry_flags, BIO_set_retry_read, BIO_set_retry_write}; +use std::any::Any; use std::io; use std::io::prelude::*; use std::mem; -use std::slice; use std::ptr; +use std::slice; use std::sync::Arc; use ssl::error::SslError; @@ -16,6 +17,7 @@ const NAME: [c_char; 5] = [114, 117, 115, 116, 0]; pub struct StreamState { pub stream: S, pub error: Option, + pub panic: Option>, } pub fn new(stream: S) -> Result<(*mut BIO, Arc), SslError> { @@ -35,6 +37,7 @@ pub fn new(stream: S) -> Result<(*mut BIO, Arc), Ss let state = Box::new(StreamState { stream: stream, error: None, + panic: None, }); unsafe { @@ -51,6 +54,12 @@ pub unsafe fn take_error(bio: *mut BIO) -> Option { state.error.take() } +#[cfg_attr(not(feature = "nightly"), allow(dead_code))] +pub unsafe fn take_panic(bio: *mut BIO) -> Option> { + let state = state::(bio); + state.panic.take() +} + pub unsafe fn get_ref<'a, S: 'a>(bio: *mut BIO) -> &'a S { let state: &'a StreamState = mem::transmute((*bio).ptr); &state.stream @@ -64,20 +73,69 @@ unsafe fn state<'a, S: 'a>(bio: *mut BIO) -> &'a mut StreamState { mem::transmute((*bio).ptr) } +#[cfg(feature = "nightly")] +fn recover(f: F) -> Result> where F: FnOnce() -> T + ::std::panic::RecoverSafe { + ::std::panic::recover(f) +} + +#[cfg(not(feature = "nightly"))] +fn recover(f: F) -> Result> where F: FnOnce() -> T { + Ok(f()) +} + +#[cfg(feature = "nightly")] +use std::panic::AssertRecoverSafe; + +#[cfg(not(feature = "nightly"))] +struct AssertRecoverSafe(T); + +#[cfg(not(feature = "nightly"))] +impl AssertRecoverSafe { + fn new(t: T) -> Self { + AssertRecoverSafe(t) + } +} + +#[cfg(not(feature = "nightly"))] +impl ::std::ops::Deref for AssertRecoverSafe { + type Target = T; + + fn deref(&self) -> &T { + &self.0 + } +} + +#[cfg(not(feature = "nightly"))] +impl ::std::ops::DerefMut for AssertRecoverSafe { + fn deref_mut(&mut self) -> &mut T { + &mut self.0 + } +} + unsafe extern "C" fn bwrite(bio: *mut BIO, buf: *const c_char, len: c_int) -> c_int { BIO_clear_retry_flags(bio); let state = state::(bio); let buf = slice::from_raw_parts(buf as *const _, len as usize); - match state.stream.write(buf) { - Ok(len) => len as c_int, - Err(err) => { + + let result = { + let mut youre_not_my_supervisor = AssertRecoverSafe::new(&mut *state); + recover(move || youre_not_my_supervisor.stream.write(buf)) + }; + + match result { + Ok(Ok(len)) => len as c_int, + Ok(Err(err)) => { if retriable_error(&err) { BIO_set_retry_write(bio); } state.error = Some(err); -1 } + Err(err) => { + state.panic = Some(err); + -1 + } } } @@ -86,15 +144,26 @@ unsafe extern "C" fn bread(bio: *mut BIO, buf: *mut c_char, len: c_int) let state = state::(bio); let buf = slice::from_raw_parts_mut(buf as *mut _, len as usize); - match state.stream.read(buf) { - Ok(len) => len as c_int, - Err(err) => { + + let result = { + let mut youre_not_my_supervisor = AssertRecoverSafe::new(&mut *state); + let mut fuuuu = AssertRecoverSafe::new(buf); + recover(move || youre_not_my_supervisor.stream.read(&mut *fuuuu)) + }; + + match result { + Ok(Ok(len)) => len as c_int, + Ok(Err(err)) => { if retriable_error(&err) { BIO_set_retry_read(bio); } state.error = Some(err); -1 } + Err(err) => { + state.panic = Some(err); + -1 + } } } @@ -116,12 +185,21 @@ unsafe extern "C" fn ctrl(bio: *mut BIO, -> c_long { if cmd == BIO_CTRL_FLUSH { let state = state::(bio); - match state.stream.flush() { - Ok(()) => 1, - Err(err) => { + let result = { + let mut youre_not_my_supervisor = AssertRecoverSafe::new(&mut *state); + recover(move || youre_not_my_supervisor.stream.flush()) + }; + + match result { + Ok(Ok(())) => 1, + Ok(Err(err)) => { state.error = Some(err); 0 } + Err(err) => { + state.panic = Some(err); + 0 + } } } else { 0 diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 955f10fd..7366cc4a 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -26,8 +26,7 @@ use std::os::windows::io::{AsRawSocket, RawSocket}; use ffi; use ffi_extras; use dh::DH; -use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError, - OpensslError}; +use ssl::error::{NonblockingSslError, SslError, OpenSslError, OpensslError}; use x509::{X509StoreContext, X509FileType, X509}; use crypto::pkey::PKey; @@ -482,6 +481,8 @@ fn wrap_ssl_result(res: c_int) -> Result<(), SslError> { } /// An SSL context object +/// +/// Internally ref-counted, use `.clone()` in the same way as Rc and Arc. pub struct SslContext { ctx: *mut ffi::SSL_CTX, } @@ -489,6 +490,12 @@ pub struct SslContext { unsafe impl Send for SslContext {} unsafe impl Sync for SslContext {} +impl Clone for SslContext { + fn clone(&self) -> Self { + unsafe { SslContext::new_ref(self.ctx) } + } +} + // TODO: add useful info here impl fmt::Debug for SslContext { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { @@ -503,6 +510,12 @@ impl Drop for SslContext { } impl SslContext { + // Create a new SslContext given an existing ref, and incriment ref-count appropriately. + unsafe fn new_ref(ctx: *mut ffi::SSL_CTX) -> SslContext { + rust_SSL_CTX_clone(ctx); + SslContext { ctx: ctx } + } + /// Creates a new SSL context. pub fn new(method: SslMethod) -> Result { init(); @@ -956,16 +969,27 @@ impl Ssl { } /// change the context corresponding to the current connection + /// + /// Returns a clone of the SslContext @ctx (ie: the new context). The old context is freed. pub fn set_ssl_context(&self, ctx: &SslContext) -> SslContext { - SslContext { ctx: unsafe { ffi::SSL_set_SSL_CTX(self.ssl, ctx.ctx) } } + // If duplication of @ctx's cert fails, this returns NULL. This _appears_ to only occur on + // allocation failures (meaning panicing is probably appropriate), but it might be nice to + // propogate the error. + assert!(unsafe { ffi::SSL_set_SSL_CTX(self.ssl, ctx.ctx) } != ptr::null_mut()); + + // FIXME: we return this reference here for compatibility, but it isn't actually required. + // This should be removed when a api-incompatabile version is to be released. + // + // ffi:SSL_set_SSL_CTX() returns copy of the ctx pointer passed to it, so it's easier for + // us to do the clone directly. + ctx.clone() } /// obtain the context corresponding to the current connection pub fn get_ssl_context(&self) -> SslContext { unsafe { let ssl_ctx = ffi::SSL_get_SSL_CTX(self.ssl); - rust_SSL_CTX_clone(ssl_ctx); - SslContext { ctx: ssl_ctx } + SslContext::new_ref(ssl_ctx) } } } @@ -1137,6 +1161,8 @@ impl SslStream { impl SslStream { fn make_error(&mut self, ret: c_int) -> Error { + self.check_panic(); + match self.ssl.get_error(ret) { LibSslError::ErrorSsl => Error::Ssl(OpenSslError::get_stack()), LibSslError::ErrorSyscall => { @@ -1163,6 +1189,8 @@ impl SslStream { } fn make_old_error(&mut self, ret: c_int) -> SslError { + self.check_panic(); + match self.ssl.get_error(ret) { LibSslError::ErrorSsl => SslError::get(), LibSslError::ErrorSyscall => { @@ -1193,6 +1221,17 @@ impl SslStream { } } + #[cfg(feature = "nightly")] + fn check_panic(&mut self) { + if let Some(err) = unsafe { bio::take_panic::(self.ssl.get_raw_rbio()) } { + ::std::panic::propagate(err) + } + } + + #[cfg(not(feature = "nightly"))] + fn check_panic(&mut self) { + } + fn get_bio_error(&mut self) -> io::Error { let error = unsafe { bio::take_error::(self.ssl.get_raw_rbio()) }; match error { diff --git a/openssl/src/ssl/tests/mod.rs b/openssl/src/ssl/tests/mod.rs index af3c005e..f5a42536 100644 --- a/openssl/src/ssl/tests/mod.rs +++ b/openssl/src/ssl/tests/mod.rs @@ -957,3 +957,104 @@ fn broken_try_clone_doesnt_crash() { let stream1 = SslStream::connect(&context, inner).unwrap(); let _stream2 = stream1.try_clone().unwrap(); } + +#[test] +#[should_panic(expected = "blammo")] +#[cfg(feature = "nightly")] +fn write_panic() { + struct ExplodingStream(TcpStream); + + impl Read for ExplodingStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + } + + impl Write for ExplodingStream { + fn write(&mut self, _: &[u8]) -> io::Result { + panic!("blammo"); + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } + } + + let (_s, stream) = Server::new(); + let stream = ExplodingStream(stream); + + let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); + let _ = SslStream::connect(&ctx, stream); +} + +#[test] +#[should_panic(expected = "blammo")] +#[cfg(feature = "nightly")] +fn read_panic() { + struct ExplodingStream(TcpStream); + + impl Read for ExplodingStream { + fn read(&mut self, _: &mut [u8]) -> io::Result { + panic!("blammo"); + } + } + + impl Write for ExplodingStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } + } + + let (_s, stream) = Server::new(); + let stream = ExplodingStream(stream); + + let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); + let _ = SslStream::connect(&ctx, stream); +} + +#[test] +#[should_panic(expected = "blammo")] +#[cfg(feature = "nightly")] +fn flush_panic() { + struct ExplodingStream(TcpStream); + + impl Read for ExplodingStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + } + + impl Write for ExplodingStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + panic!("blammo"); + } + } + + let (_s, stream) = Server::new(); + let stream = ExplodingStream(stream); + + let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); + let mut stream = SslStream::connect(&ctx, stream).unwrap(); + let _ = stream.flush(); +} + +#[test] +fn refcount_ssl_context() { + let ssl = { + let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); + ssl::Ssl::new(&ctx).unwrap() + }; + + { + let new_ctx_a = SslContext::new(SslMethod::Sslv23).unwrap(); + let _new_ctx_b = ssl.set_ssl_context(&new_ctx_a); + } +} diff --git a/openssl/src/x509/mod.rs b/openssl/src/x509/mod.rs index ffd478ef..f31de89b 100644 --- a/openssl/src/x509/mod.rs +++ b/openssl/src/x509/mod.rs @@ -507,6 +507,20 @@ impl<'ctx> X509<'ctx> { } } +extern "C" { + fn rust_X509_clone(x509: *mut ffi::X509); +} + +impl<'ctx> Clone for X509<'ctx> { + fn clone(&self) -> X509<'ctx> { + unsafe { rust_X509_clone(self.handle) } + /* FIXME: given that we now have refcounting control, 'owned' should be uneeded, the 'ctx + * is probably also uneeded. We can remove both to condense the x509 api quite a bit + */ + X509::new(self.handle, true) + } +} + impl<'ctx> Drop for X509<'ctx> { fn drop(&mut self) { if self.owned { diff --git a/openssl/test/run.sh b/openssl/test/run.sh index 229d9a1d..63b2b57c 100755 --- a/openssl/test/run.sh +++ b/openssl/test/run.sh @@ -7,6 +7,10 @@ if [ "$TEST_FEATURES" == "true" ]; then FEATURES="tlsv1_2 tlsv1_1 dtlsv1 dtlsv1_2 sslv2 sslv3 aes_xts aes_ctr npn alpn rfc5114 ecdh_auto pkcs5_pbkdf2_hmac" fi +if [ "$TRAVIS_RUST_VERSION" == "nightly" ]; then + FEATURES="$FEATURES nightly" +fi + if [ "$TRAVIS_OS_NAME" != "osx" ]; then export OPENSSL_LIB_DIR=$HOME/openssl/lib export OPENSSL_INCLUDE_DIR=$HOME/openssl/include