Prepare for direct stream support

This commit is contained in:
Steven Fackler 2015-06-27 22:37:10 -07:00
parent c722f889c1
commit 9b235a7b91
1 changed files with 175 additions and 88 deletions

View File

@ -5,7 +5,6 @@ use std::ffi::{CStr, CString};
use std::fmt; use std::fmt;
use std::io; use std::io;
use std::io::prelude::*; use std::io::prelude::*;
use std::iter;
use std::mem; use std::mem;
use std::net; use std::net;
use std::path::Path; use std::path::Path;
@ -740,97 +739,57 @@ make_LibSslError! {
ErrorWantAccept = SSL_ERROR_WANT_ACCEPT ErrorWantAccept = SSL_ERROR_WANT_ACCEPT
} }
/// A stream wrapper which handles SSL encryption for an underlying stream. struct IndirectStream<S> {
#[derive(Clone)]
pub struct SslStream<S> {
stream: S, stream: S,
ssl: Arc<Ssl>, ssl: Arc<Ssl>,
buf: Vec<u8> // Max TLS record size is 16k
buf: Box<[u8; 16 * 1024]>,
} }
impl SslStream<net::TcpStream> { impl<S: Clone> Clone for IndirectStream<S> {
/// Create a new independently owned handle to the underlying socket. fn clone(&self) -> IndirectStream<S> {
pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> { IndirectStream {
Ok(SslStream { stream: self.stream.clone(),
ssl: self.ssl.clone(),
buf: Box::new(*self.buf)
}
}
}
impl IndirectStream<net::TcpStream> {
fn try_clone(&self) -> io::Result<IndirectStream<net::TcpStream>> {
Ok(IndirectStream {
stream: try!(self.stream.try_clone()), stream: try!(self.stream.try_clone()),
ssl: self.ssl.clone(), ssl: self.ssl.clone(),
buf: self.buf.clone(), buf: Box::new(*self.buf)
}) })
} }
} }
impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug { impl<S: Read+Write> IndirectStream<S> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.stream, self.ssl)
}
}
impl<S: Read+Write> SslStream<S> {
fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let ssl = try!(ssl.into_ssl()); let ssl = try!(ssl.into_ssl());
Ok(SslStream { Ok(IndirectStream {
stream: stream, stream: stream,
ssl: Arc::new(ssl), ssl: Arc::new(ssl),
// Maximum TLS record size is 16k buf: Box::new([0; 16 * 1024]),
buf: iter::repeat(0).take(16 * 1024).collect(),
}) })
} }
pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
let mut ssl = try!(SslStream::new_base(ssl, stream)); let mut ssl = try!(IndirectStream::new_base(ssl, stream));
try!(ssl.in_retry_wrapper(|ssl| ssl.connect())); try!(ssl.in_retry_wrapper(|ssl| ssl.connect()));
Ok(ssl) Ok(ssl)
} }
pub fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
let mut ssl = try!(SslStream::new_base(ssl, stream)); let mut ssl = try!(IndirectStream::new_base(ssl, stream));
try!(ssl.in_retry_wrapper(|ssl| ssl.accept())); try!(ssl.in_retry_wrapper(|ssl| ssl.accept()));
Ok(ssl) Ok(ssl)
} }
/// # Deprecated fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError>
pub fn new_server_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> { where F: FnMut(&Ssl) -> c_int {
SslStream::new_server(ssl, stream)
}
/// # Deprecated
pub fn new_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_client(ssl, stream)
}
/// # Deprecated
pub fn new(ctx: &SslContext, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_client(ctx, stream)
}
/// # Deprecated
#[doc(hidden)]
pub fn get_inner(&mut self) -> &mut S {
self.get_mut()
}
/// Returns a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Return the certificate of the peer
pub fn get_peer_certificate(&self) -> Option<X509> {
self.ssl.get_peer_certificate()
}
/// Returns a mutable reference to the underlying stream.
///
/// ## Warning
///
/// It is inadvisable to read from or write to the underlying stream as it
/// will most likely corrupt the SSL session.
pub fn get_mut(&mut self) -> &mut S {
&mut self.stream
}
fn in_retry_wrapper<F>(&mut self, mut blk: F)
-> Result<c_int, SslError> where F: FnMut(&Ssl) -> c_int {
loop { loop {
let ret = blk(&self.ssl); let ret = blk(&self.ssl);
if ret > 0 { if ret > 0 {
@ -860,12 +819,149 @@ impl<S: Read+Write> SslStream<S> {
fn write_through(&mut self) -> io::Result<()> { fn write_through(&mut self) -> io::Result<()> {
io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ()) io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ())
} }
}
impl<S: Read+Write> Read for IndirectStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
Ok(len) => Ok(len as usize),
Err(SslSessionClosed) => Ok(0),
Err(StreamError(e)) => Err(e),
Err(e @ OpenSslErrors(_)) => {
Err(io::Error::new(io::ErrorKind::Other, e))
}
}
}
}
impl<S: Read+Write> Write for IndirectStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) {
Ok(len) => len as usize,
Err(SslSessionClosed) => 0,
Err(StreamError(e)) => return Err(e),
Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
};
try!(self.write_through());
Ok(count)
}
fn flush(&mut self) -> io::Result<()> {
try!(self.write_through());
self.stream.flush()
}
}
#[derive(Clone)]
enum StreamKind<S> {
Indirect(IndirectStream<S>),
}
impl<S> StreamKind<S> {
fn stream(&self) -> &S {
match *self {
StreamKind::Indirect(ref s) => &s.stream
}
}
fn mut_stream(&mut self) -> &mut S {
match *self {
StreamKind::Indirect(ref mut s) => &mut s.stream
}
}
fn ssl(&self) -> &Ssl {
match *self {
StreamKind::Indirect(ref s) => &s.ssl
}
}
}
/// A stream wrapper which handles SSL encryption for an underlying stream.
#[derive(Clone)]
pub struct SslStream<S> {
kind: StreamKind<S>,
}
impl SslStream<net::TcpStream> {
/// Create a new independently owned handle to the underlying socket.
pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
let kind = match self.kind {
StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone()))
};
Ok(SslStream {
kind: kind
})
}
}
impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.kind.stream(), self.kind.ssl())
}
}
impl<S: Read+Write> SslStream<S> {
pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let stream = try!(IndirectStream::new_client(ssl, stream));
Ok(SslStream {
kind: StreamKind::Indirect(stream)
})
}
pub fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let stream = try!(IndirectStream::new_server(ssl, stream));
Ok(SslStream {
kind: StreamKind::Indirect(stream)
})
}
/// # Deprecated
pub fn new_server_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_server(ssl, stream)
}
/// # Deprecated
pub fn new_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_client(ssl, stream)
}
/// # Deprecated
pub fn new(ctx: &SslContext, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_client(ctx, stream)
}
/// # Deprecated
#[doc(hidden)]
pub fn get_inner(&mut self) -> &mut S {
self.get_mut()
}
/// Returns a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
self.kind.stream()
}
/// Return the certificate of the peer
pub fn get_peer_certificate(&self) -> Option<X509> {
self.kind.ssl().get_peer_certificate()
}
/// Returns a mutable reference to the underlying stream.
///
/// ## Warning
///
/// It is inadvisable to read from or write to the underlying stream as it
/// will most likely corrupt the SSL session.
pub fn get_mut(&mut self) -> &mut S {
self.kind.mut_stream()
}
/// Get the compression currently in use. The result will be /// Get the compression currently in use. The result will be
/// either None, indicating no compression is in use, or a string /// either None, indicating no compression is in use, or a string
/// with the compression name. /// with the compression name.
pub fn get_compression(&self) -> Option<String> { pub fn get_compression(&self) -> Option<String> {
let ptr = unsafe { ffi::SSL_get_current_compression(self.ssl.ssl) }; let ptr = unsafe { ffi::SSL_get_current_compression(self.kind.ssl().ssl) };
if ptr == ptr::null() { if ptr == ptr::null() {
return None; return None;
} }
@ -886,43 +982,34 @@ impl<S: Read+Write> SslStream<S> {
/// This method needs the `npn` feature. /// This method needs the `npn` feature.
#[cfg(feature = "npn")] #[cfg(feature = "npn")]
pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> { pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> {
self.ssl.get_selected_npn_protocol() self.kind.ssl().get_selected_npn_protocol()
} }
/// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any). /// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any).
pub fn pending(&self) -> usize { pub fn pending(&self) -> usize {
self.ssl.pending() self.kind.ssl().pending()
} }
} }
impl<S: Read+Write> Read for SslStream<S> { impl<S: Read+Write> Read for SslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { match self.kind {
Ok(len) => Ok(len as usize), StreamKind::Indirect(ref mut s) => s.read(buf)
Err(SslSessionClosed) => Ok(0),
Err(StreamError(e)) => Err(e),
Err(e @ OpenSslErrors(_)) => {
Err(io::Error::new(io::ErrorKind::Other, e))
}
} }
} }
} }
impl<S: Read+Write> Write for SslStream<S> { impl<S: Read+Write> Write for SslStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) { match self.kind {
Ok(len) => len as usize, StreamKind::Indirect(ref mut s) => s.write(buf)
Err(SslSessionClosed) => 0, }
Err(StreamError(e)) => return Err(e),
Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
};
try!(self.write_through());
Ok(count)
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
try!(self.write_through()); match self.kind {
self.stream.flush() StreamKind::Indirect(ref mut s) => s.flush()
}
} }
} }