Prepare for direct stream support

This commit is contained in:
Steven Fackler 2015-06-27 22:37:10 -07:00
parent c722f889c1
commit 9b235a7b91
1 changed files with 175 additions and 88 deletions

View File

@ -5,7 +5,6 @@ use std::ffi::{CStr, CString};
use std::fmt;
use std::io;
use std::io::prelude::*;
use std::iter;
use std::mem;
use std::net;
use std::path::Path;
@ -740,97 +739,57 @@ make_LibSslError! {
ErrorWantAccept = SSL_ERROR_WANT_ACCEPT
}
/// A stream wrapper which handles SSL encryption for an underlying stream.
#[derive(Clone)]
pub struct SslStream<S> {
struct IndirectStream<S> {
stream: S,
ssl: Arc<Ssl>,
buf: Vec<u8>
// Max TLS record size is 16k
buf: Box<[u8; 16 * 1024]>,
}
impl SslStream<net::TcpStream> {
/// Create a new independently owned handle to the underlying socket.
pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
Ok(SslStream {
impl<S: Clone> Clone for IndirectStream<S> {
fn clone(&self) -> IndirectStream<S> {
IndirectStream {
stream: self.stream.clone(),
ssl: self.ssl.clone(),
buf: Box::new(*self.buf)
}
}
}
impl IndirectStream<net::TcpStream> {
fn try_clone(&self) -> io::Result<IndirectStream<net::TcpStream>> {
Ok(IndirectStream {
stream: try!(self.stream.try_clone()),
ssl: self.ssl.clone(),
buf: self.buf.clone(),
buf: Box::new(*self.buf)
})
}
}
impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.stream, self.ssl)
}
}
impl<S: Read+Write> SslStream<S> {
fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
impl<S: Read+Write> IndirectStream<S> {
fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
let ssl = try!(ssl.into_ssl());
Ok(SslStream {
Ok(IndirectStream {
stream: stream,
ssl: Arc::new(ssl),
// Maximum TLS record size is 16k
buf: iter::repeat(0).take(16 * 1024).collect(),
buf: Box::new([0; 16 * 1024]),
})
}
pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let mut ssl = try!(SslStream::new_base(ssl, stream));
fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
let mut ssl = try!(IndirectStream::new_base(ssl, stream));
try!(ssl.in_retry_wrapper(|ssl| ssl.connect()));
Ok(ssl)
}
pub fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let mut ssl = try!(SslStream::new_base(ssl, stream));
fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
let mut ssl = try!(IndirectStream::new_base(ssl, stream));
try!(ssl.in_retry_wrapper(|ssl| ssl.accept()));
Ok(ssl)
}
/// # Deprecated
pub fn new_server_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_server(ssl, stream)
}
/// # Deprecated
pub fn new_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_client(ssl, stream)
}
/// # Deprecated
pub fn new(ctx: &SslContext, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_client(ctx, stream)
}
/// # Deprecated
#[doc(hidden)]
pub fn get_inner(&mut self) -> &mut S {
self.get_mut()
}
/// Returns a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Return the certificate of the peer
pub fn get_peer_certificate(&self) -> Option<X509> {
self.ssl.get_peer_certificate()
}
/// Returns a mutable reference to the underlying stream.
///
/// ## Warning
///
/// It is inadvisable to read from or write to the underlying stream as it
/// will most likely corrupt the SSL session.
pub fn get_mut(&mut self) -> &mut S {
&mut self.stream
}
fn in_retry_wrapper<F>(&mut self, mut blk: F)
-> Result<c_int, SslError> where F: FnMut(&Ssl) -> c_int {
fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError>
where F: FnMut(&Ssl) -> c_int {
loop {
let ret = blk(&self.ssl);
if ret > 0 {
@ -860,12 +819,149 @@ impl<S: Read+Write> SslStream<S> {
fn write_through(&mut self) -> io::Result<()> {
io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ())
}
}
impl<S: Read+Write> Read for IndirectStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
Ok(len) => Ok(len as usize),
Err(SslSessionClosed) => Ok(0),
Err(StreamError(e)) => Err(e),
Err(e @ OpenSslErrors(_)) => {
Err(io::Error::new(io::ErrorKind::Other, e))
}
}
}
}
impl<S: Read+Write> Write for IndirectStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) {
Ok(len) => len as usize,
Err(SslSessionClosed) => 0,
Err(StreamError(e)) => return Err(e),
Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
};
try!(self.write_through());
Ok(count)
}
fn flush(&mut self) -> io::Result<()> {
try!(self.write_through());
self.stream.flush()
}
}
#[derive(Clone)]
enum StreamKind<S> {
Indirect(IndirectStream<S>),
}
impl<S> StreamKind<S> {
fn stream(&self) -> &S {
match *self {
StreamKind::Indirect(ref s) => &s.stream
}
}
fn mut_stream(&mut self) -> &mut S {
match *self {
StreamKind::Indirect(ref mut s) => &mut s.stream
}
}
fn ssl(&self) -> &Ssl {
match *self {
StreamKind::Indirect(ref s) => &s.ssl
}
}
}
/// A stream wrapper which handles SSL encryption for an underlying stream.
#[derive(Clone)]
pub struct SslStream<S> {
kind: StreamKind<S>,
}
impl SslStream<net::TcpStream> {
/// Create a new independently owned handle to the underlying socket.
pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
let kind = match self.kind {
StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone()))
};
Ok(SslStream {
kind: kind
})
}
}
impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.kind.stream(), self.kind.ssl())
}
}
impl<S: Read+Write> SslStream<S> {
pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let stream = try!(IndirectStream::new_client(ssl, stream));
Ok(SslStream {
kind: StreamKind::Indirect(stream)
})
}
pub fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let stream = try!(IndirectStream::new_server(ssl, stream));
Ok(SslStream {
kind: StreamKind::Indirect(stream)
})
}
/// # Deprecated
pub fn new_server_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_server(ssl, stream)
}
/// # Deprecated
pub fn new_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_client(ssl, stream)
}
/// # Deprecated
pub fn new(ctx: &SslContext, stream: S) -> Result<SslStream<S>, SslError> {
SslStream::new_client(ctx, stream)
}
/// # Deprecated
#[doc(hidden)]
pub fn get_inner(&mut self) -> &mut S {
self.get_mut()
}
/// Returns a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
self.kind.stream()
}
/// Return the certificate of the peer
pub fn get_peer_certificate(&self) -> Option<X509> {
self.kind.ssl().get_peer_certificate()
}
/// Returns a mutable reference to the underlying stream.
///
/// ## Warning
///
/// It is inadvisable to read from or write to the underlying stream as it
/// will most likely corrupt the SSL session.
pub fn get_mut(&mut self) -> &mut S {
self.kind.mut_stream()
}
/// Get the compression currently in use. The result will be
/// either None, indicating no compression is in use, or a string
/// with the compression name.
pub fn get_compression(&self) -> Option<String> {
let ptr = unsafe { ffi::SSL_get_current_compression(self.ssl.ssl) };
let ptr = unsafe { ffi::SSL_get_current_compression(self.kind.ssl().ssl) };
if ptr == ptr::null() {
return None;
}
@ -886,43 +982,34 @@ impl<S: Read+Write> SslStream<S> {
/// This method needs the `npn` feature.
#[cfg(feature = "npn")]
pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> {
self.ssl.get_selected_npn_protocol()
self.kind.ssl().get_selected_npn_protocol()
}
/// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any).
pub fn pending(&self) -> usize {
self.ssl.pending()
self.kind.ssl().pending()
}
}
impl<S: Read+Write> Read for SslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
Ok(len) => Ok(len as usize),
Err(SslSessionClosed) => Ok(0),
Err(StreamError(e)) => Err(e),
Err(e @ OpenSslErrors(_)) => {
Err(io::Error::new(io::ErrorKind::Other, e))
}
match self.kind {
StreamKind::Indirect(ref mut s) => s.read(buf)
}
}
}
impl<S: Read+Write> Write for SslStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) {
Ok(len) => len as usize,
Err(SslSessionClosed) => 0,
Err(StreamError(e)) => return Err(e),
Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
};
try!(self.write_through());
Ok(count)
match self.kind {
StreamKind::Indirect(ref mut s) => s.write(buf)
}
}
fn flush(&mut self) -> io::Result<()> {
try!(self.write_through());
self.stream.flush()
match self.kind {
StreamKind::Indirect(ref mut s) => s.flush()
}
}
}