Prepare for direct stream support
This commit is contained in:
parent
c722f889c1
commit
9b235a7b91
|
|
@ -5,7 +5,6 @@ use std::ffi::{CStr, CString};
|
|||
use std::fmt;
|
||||
use std::io;
|
||||
use std::io::prelude::*;
|
||||
use std::iter;
|
||||
use std::mem;
|
||||
use std::net;
|
||||
use std::path::Path;
|
||||
|
|
@ -740,97 +739,57 @@ make_LibSslError! {
|
|||
ErrorWantAccept = SSL_ERROR_WANT_ACCEPT
|
||||
}
|
||||
|
||||
/// A stream wrapper which handles SSL encryption for an underlying stream.
|
||||
#[derive(Clone)]
|
||||
pub struct SslStream<S> {
|
||||
struct IndirectStream<S> {
|
||||
stream: S,
|
||||
ssl: Arc<Ssl>,
|
||||
buf: Vec<u8>
|
||||
// Max TLS record size is 16k
|
||||
buf: Box<[u8; 16 * 1024]>,
|
||||
}
|
||||
|
||||
impl SslStream<net::TcpStream> {
|
||||
/// Create a new independently owned handle to the underlying socket.
|
||||
pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
|
||||
Ok(SslStream {
|
||||
impl<S: Clone> Clone for IndirectStream<S> {
|
||||
fn clone(&self) -> IndirectStream<S> {
|
||||
IndirectStream {
|
||||
stream: self.stream.clone(),
|
||||
ssl: self.ssl.clone(),
|
||||
buf: Box::new(*self.buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IndirectStream<net::TcpStream> {
|
||||
fn try_clone(&self) -> io::Result<IndirectStream<net::TcpStream>> {
|
||||
Ok(IndirectStream {
|
||||
stream: try!(self.stream.try_clone()),
|
||||
ssl: self.ssl.clone(),
|
||||
buf: self.buf.clone(),
|
||||
buf: Box::new(*self.buf)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.stream, self.ssl)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Read+Write> SslStream<S> {
|
||||
fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
impl<S: Read+Write> IndirectStream<S> {
|
||||
fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
|
||||
let ssl = try!(ssl.into_ssl());
|
||||
Ok(SslStream {
|
||||
Ok(IndirectStream {
|
||||
stream: stream,
|
||||
ssl: Arc::new(ssl),
|
||||
// Maximum TLS record size is 16k
|
||||
buf: iter::repeat(0).take(16 * 1024).collect(),
|
||||
buf: Box::new([0; 16 * 1024]),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
let mut ssl = try!(SslStream::new_base(ssl, stream));
|
||||
fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
|
||||
let mut ssl = try!(IndirectStream::new_base(ssl, stream));
|
||||
try!(ssl.in_retry_wrapper(|ssl| ssl.connect()));
|
||||
Ok(ssl)
|
||||
}
|
||||
|
||||
pub fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
let mut ssl = try!(SslStream::new_base(ssl, stream));
|
||||
fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
|
||||
let mut ssl = try!(IndirectStream::new_base(ssl, stream));
|
||||
try!(ssl.in_retry_wrapper(|ssl| ssl.accept()));
|
||||
Ok(ssl)
|
||||
}
|
||||
|
||||
/// # Deprecated
|
||||
pub fn new_server_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
SslStream::new_server(ssl, stream)
|
||||
}
|
||||
|
||||
/// # Deprecated
|
||||
pub fn new_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
SslStream::new_client(ssl, stream)
|
||||
}
|
||||
|
||||
/// # Deprecated
|
||||
pub fn new(ctx: &SslContext, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
SslStream::new_client(ctx, stream)
|
||||
}
|
||||
|
||||
/// # Deprecated
|
||||
#[doc(hidden)]
|
||||
pub fn get_inner(&mut self) -> &mut S {
|
||||
self.get_mut()
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying stream.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Return the certificate of the peer
|
||||
pub fn get_peer_certificate(&self) -> Option<X509> {
|
||||
self.ssl.get_peer_certificate()
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the underlying stream.
|
||||
///
|
||||
/// ## Warning
|
||||
///
|
||||
/// It is inadvisable to read from or write to the underlying stream as it
|
||||
/// will most likely corrupt the SSL session.
|
||||
pub fn get_mut(&mut self) -> &mut S {
|
||||
&mut self.stream
|
||||
}
|
||||
|
||||
fn in_retry_wrapper<F>(&mut self, mut blk: F)
|
||||
-> Result<c_int, SslError> where F: FnMut(&Ssl) -> c_int {
|
||||
fn in_retry_wrapper<F>(&mut self, mut blk: F) -> Result<c_int, SslError>
|
||||
where F: FnMut(&Ssl) -> c_int {
|
||||
loop {
|
||||
let ret = blk(&self.ssl);
|
||||
if ret > 0 {
|
||||
|
|
@ -860,12 +819,149 @@ impl<S: Read+Write> SslStream<S> {
|
|||
fn write_through(&mut self) -> io::Result<()> {
|
||||
io::copy(&mut *self.ssl.get_wbio(), &mut self.stream).map(|_| ())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Read+Write> Read for IndirectStream<S> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
|
||||
Ok(len) => Ok(len as usize),
|
||||
Err(SslSessionClosed) => Ok(0),
|
||||
Err(StreamError(e)) => Err(e),
|
||||
Err(e @ OpenSslErrors(_)) => {
|
||||
Err(io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Read+Write> Write for IndirectStream<S> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) {
|
||||
Ok(len) => len as usize,
|
||||
Err(SslSessionClosed) => 0,
|
||||
Err(StreamError(e)) => return Err(e),
|
||||
Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
|
||||
};
|
||||
try!(self.write_through());
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
try!(self.write_through());
|
||||
self.stream.flush()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum StreamKind<S> {
|
||||
Indirect(IndirectStream<S>),
|
||||
}
|
||||
|
||||
impl<S> StreamKind<S> {
|
||||
fn stream(&self) -> &S {
|
||||
match *self {
|
||||
StreamKind::Indirect(ref s) => &s.stream
|
||||
}
|
||||
}
|
||||
|
||||
fn mut_stream(&mut self) -> &mut S {
|
||||
match *self {
|
||||
StreamKind::Indirect(ref mut s) => &mut s.stream
|
||||
}
|
||||
}
|
||||
|
||||
fn ssl(&self) -> &Ssl {
|
||||
match *self {
|
||||
StreamKind::Indirect(ref s) => &s.ssl
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream wrapper which handles SSL encryption for an underlying stream.
|
||||
#[derive(Clone)]
|
||||
pub struct SslStream<S> {
|
||||
kind: StreamKind<S>,
|
||||
}
|
||||
|
||||
impl SslStream<net::TcpStream> {
|
||||
/// Create a new independently owned handle to the underlying socket.
|
||||
pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
|
||||
let kind = match self.kind {
|
||||
StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone()))
|
||||
};
|
||||
Ok(SslStream {
|
||||
kind: kind
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(fmt, "SslStream {{ stream: {:?}, ssl: {:?} }}", self.kind.stream(), self.kind.ssl())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Read+Write> SslStream<S> {
|
||||
pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
let stream = try!(IndirectStream::new_client(ssl, stream));
|
||||
Ok(SslStream {
|
||||
kind: StreamKind::Indirect(stream)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_server<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
let stream = try!(IndirectStream::new_server(ssl, stream));
|
||||
Ok(SslStream {
|
||||
kind: StreamKind::Indirect(stream)
|
||||
})
|
||||
}
|
||||
|
||||
/// # Deprecated
|
||||
pub fn new_server_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
SslStream::new_server(ssl, stream)
|
||||
}
|
||||
|
||||
/// # Deprecated
|
||||
pub fn new_from(ssl: Ssl, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
SslStream::new_client(ssl, stream)
|
||||
}
|
||||
|
||||
/// # Deprecated
|
||||
pub fn new(ctx: &SslContext, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
SslStream::new_client(ctx, stream)
|
||||
}
|
||||
|
||||
/// # Deprecated
|
||||
#[doc(hidden)]
|
||||
pub fn get_inner(&mut self) -> &mut S {
|
||||
self.get_mut()
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying stream.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
self.kind.stream()
|
||||
}
|
||||
|
||||
/// Return the certificate of the peer
|
||||
pub fn get_peer_certificate(&self) -> Option<X509> {
|
||||
self.kind.ssl().get_peer_certificate()
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the underlying stream.
|
||||
///
|
||||
/// ## Warning
|
||||
///
|
||||
/// It is inadvisable to read from or write to the underlying stream as it
|
||||
/// will most likely corrupt the SSL session.
|
||||
pub fn get_mut(&mut self) -> &mut S {
|
||||
self.kind.mut_stream()
|
||||
}
|
||||
|
||||
/// Get the compression currently in use. The result will be
|
||||
/// either None, indicating no compression is in use, or a string
|
||||
/// with the compression name.
|
||||
pub fn get_compression(&self) -> Option<String> {
|
||||
let ptr = unsafe { ffi::SSL_get_current_compression(self.ssl.ssl) };
|
||||
let ptr = unsafe { ffi::SSL_get_current_compression(self.kind.ssl().ssl) };
|
||||
if ptr == ptr::null() {
|
||||
return None;
|
||||
}
|
||||
|
|
@ -886,43 +982,34 @@ impl<S: Read+Write> SslStream<S> {
|
|||
/// This method needs the `npn` feature.
|
||||
#[cfg(feature = "npn")]
|
||||
pub fn get_selected_npn_protocol(&self) -> Option<&[u8]> {
|
||||
self.ssl.get_selected_npn_protocol()
|
||||
self.kind.ssl().get_selected_npn_protocol()
|
||||
}
|
||||
|
||||
/// pending() takes into account only bytes from the TLS/SSL record that is currently being processed (if any).
|
||||
pub fn pending(&self) -> usize {
|
||||
self.ssl.pending()
|
||||
self.kind.ssl().pending()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Read+Write> Read for SslStream<S> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
|
||||
Ok(len) => Ok(len as usize),
|
||||
Err(SslSessionClosed) => Ok(0),
|
||||
Err(StreamError(e)) => Err(e),
|
||||
Err(e @ OpenSslErrors(_)) => {
|
||||
Err(io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
match self.kind {
|
||||
StreamKind::Indirect(ref mut s) => s.read(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Read+Write> Write for SslStream<S> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) {
|
||||
Ok(len) => len as usize,
|
||||
Err(SslSessionClosed) => 0,
|
||||
Err(StreamError(e)) => return Err(e),
|
||||
Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
|
||||
};
|
||||
try!(self.write_through());
|
||||
Ok(count)
|
||||
match self.kind {
|
||||
StreamKind::Indirect(ref mut s) => s.write(buf)
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
try!(self.write_through());
|
||||
self.stream.flush()
|
||||
match self.kind {
|
||||
StreamKind::Indirect(ref mut s) => s.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue