From 8f56897043f8138980ce3376765b769c764d8701 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Wed, 9 Dec 2015 22:02:02 -0800 Subject: [PATCH] Implement read and write --- openssl-sys-extras/src/lib.rs | 2 + openssl-sys-extras/src/openssl_shim.c | 4 + openssl/src/ssl/mod.rs | 226 +++++++++++++++----------- openssl/src/ssl/tests/mod.rs | 12 ++ 4 files changed, 148 insertions(+), 96 deletions(-) diff --git a/openssl-sys-extras/src/lib.rs b/openssl-sys-extras/src/lib.rs index dfeb06e5..3c114726 100644 --- a/openssl-sys-extras/src/lib.rs +++ b/openssl-sys-extras/src/lib.rs @@ -55,6 +55,8 @@ extern { pub fn BIO_set_retry_read(b: *mut BIO); #[link_name = "BIO_set_retry_write_shim"] pub fn BIO_set_retry_write(b: *mut BIO); + #[link_name = "BIO_flush"] + pub fn BIO_flush(b: *mut BIO) -> c_long; pub fn SSL_CTX_set_options_shim(ctx: *mut SSL_CTX, options: c_long) -> c_long; pub fn SSL_CTX_get_options_shim(ctx: *mut SSL_CTX) -> c_long; pub fn SSL_CTX_clear_options_shim(ctx: *mut SSL_CTX, options: c_long) -> c_long; diff --git a/openssl-sys-extras/src/openssl_shim.c b/openssl-sys-extras/src/openssl_shim.c index 95847ac1..cc42fbf4 100644 --- a/openssl-sys-extras/src/openssl_shim.c +++ b/openssl-sys-extras/src/openssl_shim.c @@ -103,6 +103,10 @@ void BIO_set_retry_write_shim(BIO *b) { BIO_set_retry_write(b); } +long BIO_flush_shim(BIO *b) { + return BIO_flush(b); +} + long SSL_CTX_set_options_shim(SSL_CTX *ctx, long options) { return SSL_CTX_set_options(ctx, options); } diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 5d24fc32..e12d694d 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -1350,102 +1350,6 @@ impl Write for SslStream { } } -pub struct SslStreamNg { - ssl: Ssl, - _method: Box, // :( - _p: PhantomData, -} - -impl Drop for SslStreamNg { - fn drop(&mut self) { - unsafe { - let _ = bio::take_stream::(self.ssl.get_raw_rbio()); - } - } -} - -impl SslStreamNg { - fn new_base(ssl: Ssl, stream: S) -> Result { - unsafe { - let (bio, method) = try!(bio::new(stream)); - ffi::SSL_set_bio(ssl.ssl, bio, bio); - - Ok(SslStreamNg { - ssl: ssl, - _method: method, - _p: PhantomData, - }) - } - } - - /// Creates an SSL/TLS client operating over the provided stream. - pub fn connect(ssl: T, stream: S) -> Result { - let ssl = try!(ssl.into_ssl()); - let mut stream = try!(Self::new_base(ssl, stream)); - let ret = stream.ssl.connect(); - if ret > 0 { - Ok(stream) - } else { - Err(stream.make_error(ret)) - } - } - - /// Creates an SSL/TLS server operating over the provided stream. - pub fn accept(ssl: T, stream: S) -> Result { - let ssl = try!(ssl.into_ssl()); - let mut stream = try!(Self::new_base(ssl, stream)); - let ret = stream.ssl.accept(); - if ret > 0 { - Ok(stream) - } else { - Err(stream.make_error(ret)) - } - } - - pub fn get_ref(&self) -> &S { - unsafe { - let bio = self.ssl.get_raw_rbio(); - bio::get_ref(bio) - } - } - - pub fn mut_ref(&mut self) -> &mut S { - unsafe { - let bio = self.ssl.get_raw_rbio(); - bio::get_mut(bio) - } - } - - fn make_error(&mut self, ret: c_int) -> SslError { - match self.ssl.get_error(ret) { - LibSslError::ErrorSsl => SslError::get(), - LibSslError::ErrorSyscall => { - let err = SslError::get(); - let count = match err { - SslError::OpenSslErrors(ref v) => v.len(), - _ => unreachable!(), - }; - if count == 0 { - if ret == 0 { - SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, - "unexpected EOF observed")) - } else { - let error = unsafe { bio::take_error::(self.ssl.get_raw_rbio()) }; - SslError::StreamError(error.unwrap()) - } - } else { - err - } - } - LibSslError::ErrorWantWrite | LibSslError::ErrorWantRead => { - let error = unsafe { bio::take_error::(self.ssl.get_raw_rbio()) }; - SslError::StreamError(error.unwrap()) - } - err => panic!("unexpected error {:?} with ret {}", err, ret), - } - } -} - pub trait IntoSsl { fn into_ssl(self) -> Result; } @@ -1756,3 +1660,133 @@ impl NonblockingSslStream { } } } + +pub struct SslStreamNg { + ssl: Ssl, + _method: Box, // :( + _p: PhantomData, +} + +impl Drop for SslStreamNg { + fn drop(&mut self) { + unsafe { + let _ = bio::take_stream::(self.ssl.get_raw_rbio()); + } + } +} + +impl SslStreamNg { + fn new_base(ssl: Ssl, stream: S) -> Result { + unsafe { + let (bio, method) = try!(bio::new(stream)); + ffi::SSL_set_bio(ssl.ssl, bio, bio); + + Ok(SslStreamNg { + ssl: ssl, + _method: method, + _p: PhantomData, + }) + } + } + + /// Creates an SSL/TLS client operating over the provided stream. + pub fn connect(ssl: T, stream: S) -> Result { + let ssl = try!(ssl.into_ssl()); + let mut stream = try!(Self::new_base(ssl, stream)); + let ret = stream.ssl.connect(); + if ret > 0 { + Ok(stream) + } else { + Err(stream.make_error(ret)) + } + } + + /// Creates an SSL/TLS server operating over the provided stream. + pub fn accept(ssl: T, stream: S) -> Result { + let ssl = try!(ssl.into_ssl()); + let mut stream = try!(Self::new_base(ssl, stream)); + let ret = stream.ssl.accept(); + if ret > 0 { + Ok(stream) + } else { + Err(stream.make_error(ret)) + } + } +} + +impl SslStreamNg { + pub fn get_ref(&self) -> &S { + unsafe { + let bio = self.ssl.get_raw_rbio(); + bio::get_ref(bio) + } + } + + pub fn get_mut(&mut self) -> &mut S { + unsafe { + let bio = self.ssl.get_raw_rbio(); + bio::get_mut(bio) + } + } + + fn make_error(&mut self, ret: c_int) -> SslError { + match self.ssl.get_error(ret) { + LibSslError::ErrorSsl => SslError::get(), + LibSslError::ErrorSyscall => { + let err = SslError::get(); + let count = match err { + SslError::OpenSslErrors(ref v) => v.len(), + _ => unreachable!(), + }; + if count == 0 { + if ret == 0 { + SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed")) + } else { + let error = unsafe { bio::take_error::(self.ssl.get_raw_rbio()) }; + SslError::StreamError(error.unwrap()) + } + } else { + err + } + } + LibSslError::ErrorWantWrite | LibSslError::ErrorWantRead => { + let error = unsafe { bio::take_error::(self.ssl.get_raw_rbio()) }; + SslError::StreamError(error.unwrap()) + } + err => panic!("unexpected error {:?} with ret {}", err, ret), + } + } +} + +impl Read for SslStreamNg { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let ret = self.ssl.read(buf); + if ret >= 0 { + return Ok(ret as usize); + } + + match self.make_error(ret) { + SslError::StreamError(e) => Err(e), + e => Err(io::Error::new(io::ErrorKind::Other, e)), + } + } +} + +impl Write for SslStreamNg { + fn write(&mut self, buf: &[u8]) -> io::Result { + let ret = self.ssl.write(buf); + if ret > 0 { + return Ok(ret as usize); + } + + match self.make_error(ret) { + SslError::StreamError(e) => Err(e), + e => Err(io::Error::new(io::ErrorKind::Other, e)), + } + } + + fn flush(&mut self) -> io::Result<()> { + self.get_mut().flush() + } +} diff --git a/openssl/src/ssl/tests/mod.rs b/openssl/src/ssl/tests/mod.rs index d1f34019..13d9371c 100644 --- a/openssl/src/ssl/tests/mod.rs +++ b/openssl/src/ssl/tests/mod.rs @@ -936,3 +936,15 @@ fn ng_connect() { let ctx = SslContext::new(Sslv23).unwrap(); SslStreamNg::connect(&ctx, stream).unwrap(); } + +#[test] +fn ng_get() { + let (_s, stream) = Server::new(); + let ctx = SslContext::new(Sslv23).unwrap(); + let mut stream = SslStreamNg::connect(&ctx, stream).unwrap(); + stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); + let mut resp = String::new(); + stream.read_to_string(&mut resp).unwrap(); + assert!(resp.starts_with("HTTP/1.0 200")); + assert!(resp.ends_with("\r\n\r\n")); +}