From 9b235a7b9121613780810b0bc7b4d1f30dc861c9 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 27 Jun 2015 22:37:10 -0700 Subject: [PATCH] Prepare for direct stream support --- openssl/src/ssl/mod.rs | 263 +++++++++++++++++++++++++++-------------- 1 file changed, 175 insertions(+), 88 deletions(-) diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index cb4448b8..18acf7f8 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -5,7 +5,6 @@ use std::ffi::{CStr, CString}; use std::fmt; use std::io; use std::io::prelude::*; -use std::iter; use std::mem; use std::net; use std::path::Path; @@ -740,97 +739,57 @@ make_LibSslError! { ErrorWantAccept = SSL_ERROR_WANT_ACCEPT } -/// A stream wrapper which handles SSL encryption for an underlying stream. -#[derive(Clone)] -pub struct SslStream { +struct IndirectStream { stream: S, ssl: Arc, - buf: Vec + // Max TLS record size is 16k + buf: Box<[u8; 16 * 1024]>, } -impl SslStream { - /// Create a new independently owned handle to the underlying socket. - pub fn try_clone(&self) -> io::Result> { - Ok(SslStream { +impl Clone for IndirectStream { + fn clone(&self) -> IndirectStream { + IndirectStream { + stream: self.stream.clone(), + ssl: self.ssl.clone(), + buf: Box::new(*self.buf) + } + } +} + +impl IndirectStream { + fn try_clone(&self) -> io::Result> { + Ok(IndirectStream { stream: try!(self.stream.try_clone()), ssl: self.ssl.clone(), - buf: self.buf.clone(), + buf: Box::new(*self.buf) }) } } -impl fmt::Debug for SslStream where S: fmt::Debug { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.stream, self.ssl) - } -} - -impl SslStream { - fn new_base(ssl: T, stream: S) -> Result, SslError> { +impl IndirectStream { + fn new_base(ssl: T, stream: S) -> Result, SslError> { let ssl = try!(ssl.into_ssl()); - Ok(SslStream { + Ok(IndirectStream { stream: stream, ssl: Arc::new(ssl), - // Maximum TLS record size is 16k - buf: iter::repeat(0).take(16 * 1024).collect(), + buf: Box::new([0; 16 * 1024]), }) } - pub fn new_client(ssl: T, stream: S) -> Result, SslError> { - let mut ssl = try!(SslStream::new_base(ssl, stream)); + fn new_client(ssl: T, stream: S) -> Result, SslError> { + let mut ssl = try!(IndirectStream::new_base(ssl, stream)); try!(ssl.in_retry_wrapper(|ssl| ssl.connect())); Ok(ssl) } - pub fn new_server(ssl: T, stream: S) -> Result, SslError> { - let mut ssl = try!(SslStream::new_base(ssl, stream)); + fn new_server(ssl: T, stream: S) -> Result, SslError> { + let mut ssl = try!(IndirectStream::new_base(ssl, stream)); try!(ssl.in_retry_wrapper(|ssl| ssl.accept())); Ok(ssl) } - /// # Deprecated - pub fn new_server_from(ssl: Ssl, stream: S) -> Result, SslError> { - SslStream::new_server(ssl, stream) - } - - /// # Deprecated - pub fn new_from(ssl: Ssl, stream: S) -> Result, SslError> { - SslStream::new_client(ssl, stream) - } - - /// # Deprecated - pub fn new(ctx: &SslContext, stream: S) -> Result, SslError> { - SslStream::new_client(ctx, stream) - } - - /// # Deprecated - #[doc(hidden)] - pub fn get_inner(&mut self) -> &mut S { - self.get_mut() - } - - /// Returns a reference to the underlying stream. - pub fn get_ref(&self) -> &S { - &self.stream - } - - /// Return the certificate of the peer - pub fn get_peer_certificate(&self) -> Option { - self.ssl.get_peer_certificate() - } - - /// Returns a mutable reference to the underlying stream. - /// - /// ## Warning - /// - /// It is inadvisable to read from or write to the underlying stream as it - /// will most likely corrupt the SSL session. - pub fn get_mut(&mut self) -> &mut S { - &mut self.stream - } - - fn in_retry_wrapper(&mut self, mut blk: F) - -> Result where F: FnMut(&Ssl) -> c_int { + fn in_retry_wrapper(&mut self, mut blk: F) -> Result + where F: FnMut(&Ssl) -> c_int { loop { let ret = blk(&self.ssl); if ret > 0 { @@ -860,12 +819,149 @@ impl SslStream { fn write_through(&mut self) -> io::Result<()> { io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) } +} + +impl Read for IndirectStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { + Ok(len) => Ok(len as usize), + Err(SslSessionClosed) => Ok(0), + Err(StreamError(e)) => Err(e), + Err(e @ OpenSslErrors(_)) => { + Err(io::Error::new(io::ErrorKind::Other, e)) + } + } + } +} + +impl Write for IndirectStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { + Ok(len) => len as usize, + Err(SslSessionClosed) => 0, + Err(StreamError(e)) => return Err(e), + Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), + }; + try!(self.write_through()); + Ok(count) + } + + fn flush(&mut self) -> io::Result<()> { + try!(self.write_through()); + self.stream.flush() + } +} + +#[derive(Clone)] +enum StreamKind { + Indirect(IndirectStream), +} + +impl StreamKind { + fn stream(&self) -> &S { + match *self { + StreamKind::Indirect(ref s) => &s.stream + } + } + + fn mut_stream(&mut self) -> &mut S { + match *self { + StreamKind::Indirect(ref mut s) => &mut s.stream + } + } + + fn ssl(&self) -> &Ssl { + match *self { + StreamKind::Indirect(ref s) => &s.ssl + } + } +} + +/// A stream wrapper which handles SSL encryption for an underlying stream. +#[derive(Clone)] +pub struct SslStream { + kind: StreamKind, +} + +impl SslStream { + /// Create a new independently owned handle to the underlying socket. + pub fn try_clone(&self) -> io::Result> { + let kind = match self.kind { + StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())) + }; + Ok(SslStream { + kind: kind + }) + } +} + +impl fmt::Debug for SslStream where S: fmt::Debug { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.kind.stream(), self.kind.ssl()) + } +} + +impl SslStream { + pub fn new_client(ssl: T, stream: S) -> Result, SslError> { + let stream = try!(IndirectStream::new_client(ssl, stream)); + Ok(SslStream { + kind: StreamKind::Indirect(stream) + }) + } + + pub fn new_server(ssl: T, stream: S) -> Result, SslError> { + let stream = try!(IndirectStream::new_server(ssl, stream)); + Ok(SslStream { + kind: StreamKind::Indirect(stream) + }) + } + + /// # Deprecated + pub fn new_server_from(ssl: Ssl, stream: S) -> Result, SslError> { + SslStream::new_server(ssl, stream) + } + + /// # Deprecated + pub fn new_from(ssl: Ssl, stream: S) -> Result, SslError> { + SslStream::new_client(ssl, stream) + } + + /// # Deprecated + pub fn new(ctx: &SslContext, stream: S) -> Result, SslError> { + SslStream::new_client(ctx, stream) + } + + /// # Deprecated + #[doc(hidden)] + pub fn get_inner(&mut self) -> &mut S { + self.get_mut() + } + + /// Returns a reference to the underlying stream. + pub fn get_ref(&self) -> &S { + self.kind.stream() + } + + /// Return the certificate of the peer + pub fn get_peer_certificate(&self) -> Option { + self.kind.ssl().get_peer_certificate() + } + + /// Returns a mutable reference to the underlying stream. + /// + /// ## Warning + /// + /// It is inadvisable to read from or write to the underlying stream as it + /// will most likely corrupt the SSL session. + pub fn get_mut(&mut self) -> &mut S { + self.kind.mut_stream() + } /// Get the compression currently in use. The result will be /// either None, indicating no compression is in use, or a string /// with the compression name. pub fn get_compression(&self) -> Option { - let ptr = unsafe { ffi::SSL_get_current_compression(self.ssl.ssl) }; + let ptr = unsafe { ffi::SSL_get_current_compression(self.kind.ssl().ssl) }; if ptr == ptr::null() { return None; } @@ -886,43 +982,34 @@ impl SslStream { /// This method needs the `npn` feature. #[cfg(feature = "npn")] pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> { - self.ssl.get_selected_npn_protocol() + self.kind.ssl().get_selected_npn_protocol() } /// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any). pub fn pending(&self) -> usize { - self.ssl.pending() + self.kind.ssl().pending() } } impl Read for SslStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { - Ok(len) => Ok(len as usize), - Err(SslSessionClosed) => Ok(0), - Err(StreamError(e)) => Err(e), - Err(e @ OpenSslErrors(_)) => { - Err(io::Error::new(io::ErrorKind::Other, e)) - } + match self.kind { + StreamKind::Indirect(ref mut s) => s.read(buf) } } } impl Write for SslStream { fn write(&mut self, buf: &[u8]) -> io::Result { - let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { - Ok(len) => len as usize, - Err(SslSessionClosed) => 0, - Err(StreamError(e)) => return Err(e), - Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)), - }; - try!(self.write_through()); - Ok(count) + match self.kind { + StreamKind::Indirect(ref mut s) => s.write(buf) + } } fn flush(&mut self) -> io::Result<()> { - try!(self.write_through()); - self.stream.flush() + match self.kind { + StreamKind::Indirect(ref mut s) => s.flush() + } } }