Have NonblockingSslStream delegate to SslStream
This commit is contained in:
parent
1df131ff81
commit
d6ce9afdf3
|
|
@ -95,6 +95,11 @@ impl OpenSslError {
|
||||||
errs
|
errs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the raw OpenSSL error code for this error.
|
||||||
|
pub fn error_code(&self) -> c_ulong {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the name of the library reporting the error.
|
/// Returns the name of the library reporting the error.
|
||||||
pub fn library(&self) -> &'static str {
|
pub fn library(&self) -> &'static str {
|
||||||
get_lib(self.0)
|
get_lib(self.0)
|
||||||
|
|
@ -239,6 +244,17 @@ pub enum OpensslError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl OpensslError {
|
||||||
|
pub fn from_error_code(err: c_ulong) -> OpensslError {
|
||||||
|
ffi::init();
|
||||||
|
UnknownError {
|
||||||
|
library: get_lib(err).to_owned(),
|
||||||
|
function: get_func(err).to_owned(),
|
||||||
|
reason: get_reason(err).to_owned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn get_lib(err: c_ulong) -> &'static str {
|
fn get_lib(err: c_ulong) -> &'static str {
|
||||||
unsafe {
|
unsafe {
|
||||||
let cstr = ffi::ERR_lib_error_string(err);
|
let cstr = ffi::ERR_lib_error_string(err);
|
||||||
|
|
@ -271,7 +287,7 @@ impl SslError {
|
||||||
loop {
|
loop {
|
||||||
match unsafe { ffi::ERR_get_error() } {
|
match unsafe { ffi::ERR_get_error() } {
|
||||||
0 => break,
|
0 => break,
|
||||||
err => errs.push(SslError::from_error_code(err))
|
err => errs.push(OpensslError::from_error_code(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
OpenSslErrors(errs)
|
OpenSslErrors(errs)
|
||||||
|
|
@ -279,16 +295,7 @@ impl SslError {
|
||||||
|
|
||||||
/// Creates an `SslError` from the raw numeric error code.
|
/// Creates an `SslError` from the raw numeric error code.
|
||||||
pub fn from_error(err: c_ulong) -> SslError {
|
pub fn from_error(err: c_ulong) -> SslError {
|
||||||
OpenSslErrors(vec![SslError::from_error_code(err)])
|
OpenSslErrors(vec![OpensslError::from_error_code(err)])
|
||||||
}
|
|
||||||
|
|
||||||
fn from_error_code(err: c_ulong) -> OpensslError {
|
|
||||||
ffi::init();
|
|
||||||
UnknownError {
|
|
||||||
library: get_lib(err).to_owned(),
|
|
||||||
function: get_func(err).to_owned(),
|
|
||||||
reason: get_reason(err).to_owned()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ use std::str;
|
||||||
use std::net;
|
use std::net;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::ptr;
|
use std::ptr;
|
||||||
use std::sync::{Once, ONCE_INIT, Arc, Mutex};
|
use std::sync::{Once, ONCE_INIT, Mutex};
|
||||||
use std::cmp;
|
use std::cmp;
|
||||||
use std::any::Any;
|
use std::any::Any;
|
||||||
#[cfg(any(feature = "npn", feature = "alpn"))]
|
#[cfg(any(feature = "npn", feature = "alpn"))]
|
||||||
|
|
@ -18,11 +18,16 @@ use libc::{c_uchar, c_uint};
|
||||||
#[cfg(any(feature = "npn", feature = "alpn"))]
|
#[cfg(any(feature = "npn", feature = "alpn"))]
|
||||||
use std::slice;
|
use std::slice;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
#[cfg(unix)]
|
||||||
|
use std::os::unix::io::{AsRawFd, RawFd};
|
||||||
|
#[cfg(windows)]
|
||||||
|
use std::os::windows::io::{AsRawSocket, RawSocket};
|
||||||
|
|
||||||
use ffi;
|
use ffi;
|
||||||
use ffi_extras;
|
use ffi_extras;
|
||||||
use dh::DH;
|
use dh::DH;
|
||||||
use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError};
|
use ssl::error::{NonblockingSslError, SslError, StreamError, OpenSslErrors, OpenSslError,
|
||||||
|
OpensslError};
|
||||||
use x509::{X509StoreContext, X509FileType, X509};
|
use x509::{X509StoreContext, X509FileType, X509};
|
||||||
use crypto::pkey::PKey;
|
use crypto::pkey::PKey;
|
||||||
|
|
||||||
|
|
@ -935,6 +940,20 @@ impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
impl<S: AsRawFd> AsRawFd for SslStream<S> {
|
||||||
|
fn as_raw_fd(&self) -> RawFd {
|
||||||
|
self.get_ref().as_raw_fd()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(windows)]
|
||||||
|
impl<S: AsRawSocket> AsRawSocket for NonblockingSslStream<S> {
|
||||||
|
fn as_raw_fd(&self) -> RawSocket {
|
||||||
|
self.0.as_raw_socket()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<S: Read+Write> SslStream<S> {
|
impl<S: Read+Write> SslStream<S> {
|
||||||
fn new_base(ssl: Ssl, stream: S) -> Self {
|
fn new_base(ssl: Ssl, stream: S) -> Self {
|
||||||
unsafe {
|
unsafe {
|
||||||
|
|
@ -1247,65 +1266,38 @@ impl MaybeSslStream<net::TcpStream> {
|
||||||
/// # Deprecated
|
/// # Deprecated
|
||||||
///
|
///
|
||||||
/// Use `SslStream` with `ssl_read` and `ssl_write`.
|
/// Use `SslStream` with `ssl_read` and `ssl_write`.
|
||||||
#[derive(Clone)]
|
pub struct NonblockingSslStream<S>(SslStream<S>);
|
||||||
pub struct NonblockingSslStream<S> {
|
|
||||||
stream: S,
|
impl<S: Clone + Read + Write> Clone for NonblockingSslStream<S> {
|
||||||
ssl: Arc<Ssl>,
|
fn clone(&self) -> Self {
|
||||||
|
NonblockingSslStream(self.0.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
impl<S: AsRawFd> AsRawFd for NonblockingSslStream<S> {
|
||||||
|
fn as_raw_fd(&self) -> RawFd {
|
||||||
|
self.0.as_raw_fd()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(windows)]
|
||||||
|
impl<S: AsRawSocket> AsRawSocket for NonblockingSslStream<S> {
|
||||||
|
fn as_raw_fd(&self) -> RawSocket {
|
||||||
|
self.0.as_raw_socket()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl NonblockingSslStream<net::TcpStream> {
|
impl NonblockingSslStream<net::TcpStream> {
|
||||||
pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> {
|
pub fn try_clone(&self) -> io::Result<NonblockingSslStream<net::TcpStream>> {
|
||||||
Ok(NonblockingSslStream {
|
self.0.try_clone().map(NonblockingSslStream)
|
||||||
stream: try!(self.stream.try_clone()),
|
|
||||||
ssl: self.ssl.clone(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> NonblockingSslStream<S> {
|
impl<S> NonblockingSslStream<S> {
|
||||||
fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<NonblockingSslStream<S>, SslError> {
|
|
||||||
unsafe {
|
|
||||||
let bio = try_ssl_null!(ffi::BIO_new_socket(sock, 0));
|
|
||||||
ffi_extras::BIO_set_nbio(bio, 1);
|
|
||||||
ffi::SSL_set_bio(ssl.ssl, bio, bio);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(NonblockingSslStream {
|
|
||||||
stream: stream,
|
|
||||||
ssl: Arc::new(ssl),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn make_error(&self, ret: c_int) -> NonblockingSslError {
|
|
||||||
match self.ssl.get_error(ret) {
|
|
||||||
LibSslError::ErrorSsl => NonblockingSslError::SslError(SslError::get()),
|
|
||||||
LibSslError::ErrorSyscall => {
|
|
||||||
let err = SslError::get();
|
|
||||||
let count = match err {
|
|
||||||
SslError::OpenSslErrors(ref v) => v.len(),
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
let ssl_error = if count == 0 {
|
|
||||||
if ret == 0 {
|
|
||||||
SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
|
|
||||||
"unexpected EOF observed"))
|
|
||||||
} else {
|
|
||||||
SslError::StreamError(io::Error::last_os_error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err
|
|
||||||
};
|
|
||||||
ssl_error.into()
|
|
||||||
},
|
|
||||||
LibSslError::ErrorWantWrite => NonblockingSslError::WantWrite,
|
|
||||||
LibSslError::ErrorWantRead => NonblockingSslError::WantRead,
|
|
||||||
err => panic!("unexpected error {:?} with ret {}", err, ret),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a reference to the underlying stream.
|
/// Returns a reference to the underlying stream.
|
||||||
pub fn get_ref(&self) -> &S {
|
pub fn get_ref(&self) -> &S {
|
||||||
&self.stream
|
self.0.get_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a mutable reference to the underlying stream.
|
/// Returns a mutable reference to the underlying stream.
|
||||||
|
|
@ -1315,117 +1307,50 @@ impl<S> NonblockingSslStream<S> {
|
||||||
/// It is inadvisable to read from or write to the underlying stream as it
|
/// It is inadvisable to read from or write to the underlying stream as it
|
||||||
/// will most likely corrupt the SSL session.
|
/// will most likely corrupt the SSL session.
|
||||||
pub fn get_mut(&mut self) -> &mut S {
|
pub fn get_mut(&mut self) -> &mut S {
|
||||||
&mut self.stream
|
self.0.get_mut()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a reference to the Ssl.
|
/// Returns a reference to the Ssl.
|
||||||
pub fn ssl(&self) -> &Ssl {
|
pub fn ssl(&self) -> &Ssl {
|
||||||
&self.ssl
|
self.0.ssl()
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(unix)]
|
|
||||||
impl<S: Read+Write+::std::os::unix::io::AsRawFd> NonblockingSslStream<S> {
|
|
||||||
/// Create a new nonblocking client ssl connection on wrapped `stream`.
|
|
||||||
///
|
|
||||||
/// Note that this method will most likely not actually complete the SSL
|
|
||||||
/// handshake because doing so requires several round trips; the handshake will
|
|
||||||
/// be completed in subsequent read/write calls managed by your event loop.
|
|
||||||
pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
|
|
||||||
let ssl = try!(ssl.into_ssl());
|
|
||||||
let fd = stream.as_raw_fd() as c_int;
|
|
||||||
let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
|
|
||||||
let ret = ssl.ssl.connect();
|
|
||||||
if ret > 0 {
|
|
||||||
Ok(ssl)
|
|
||||||
} else {
|
|
||||||
// WantRead/WantWrite is okay here; we'll finish the handshake in
|
|
||||||
// subsequent send/recv calls.
|
|
||||||
match ssl.make_error(ret) {
|
|
||||||
NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
|
|
||||||
NonblockingSslError::SslError(other) => Err(other),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new nonblocking server ssl connection on wrapped `stream`.
|
|
||||||
///
|
|
||||||
/// Note that this method will most likely not actually complete the SSL
|
|
||||||
/// handshake because doing so requires several round trips; the handshake will
|
|
||||||
/// be completed in subsequent read/write calls managed by your event loop.
|
|
||||||
pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
|
|
||||||
let ssl = try!(ssl.into_ssl());
|
|
||||||
let fd = stream.as_raw_fd() as c_int;
|
|
||||||
let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
|
|
||||||
let ret = ssl.ssl.accept();
|
|
||||||
if ret > 0 {
|
|
||||||
Ok(ssl)
|
|
||||||
} else {
|
|
||||||
// WantRead/WantWrite is okay here; we'll finish the handshake in
|
|
||||||
// subsequent send/recv calls.
|
|
||||||
match ssl.make_error(ret) {
|
|
||||||
NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
|
|
||||||
NonblockingSslError::SslError(other) => Err(other),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(unix)]
|
|
||||||
impl<S: ::std::os::unix::io::AsRawFd> ::std::os::unix::io::AsRawFd for NonblockingSslStream<S> {
|
|
||||||
fn as_raw_fd(&self) -> ::std::os::unix::io::RawFd {
|
|
||||||
self.stream.as_raw_fd()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(windows)]
|
|
||||||
impl<S: Read+Write+::std::os::windows::io::AsRawSocket> NonblockingSslStream<S> {
|
|
||||||
/// Create a new nonblocking client ssl connection on wrapped `stream`.
|
|
||||||
///
|
|
||||||
/// Note that this method will most likely not actually complete the SSL
|
|
||||||
/// handshake because doing so requires several round trips; the handshake will
|
|
||||||
/// be completed in subsequent read/write calls managed by your event loop.
|
|
||||||
pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
|
|
||||||
let ssl = try!(ssl.into_ssl());
|
|
||||||
let fd = stream.as_raw_socket() as c_int;
|
|
||||||
let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
|
|
||||||
let ret = ssl.ssl.connect();
|
|
||||||
if ret > 0 {
|
|
||||||
Ok(ssl)
|
|
||||||
} else {
|
|
||||||
// WantRead/WantWrite is okay here; we'll finish the handshake in
|
|
||||||
// subsequent send/recv calls.
|
|
||||||
match ssl.make_error(ret) {
|
|
||||||
NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
|
|
||||||
NonblockingSslError::SslError(other) => Err(other),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new nonblocking server ssl connection on wrapped `stream`.
|
|
||||||
///
|
|
||||||
/// Note that this method will most likely not actually complete the SSL
|
|
||||||
/// handshake because doing so requires several round trips; the handshake will
|
|
||||||
/// be completed in subsequent read/write calls managed by your event loop.
|
|
||||||
pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
|
|
||||||
let ssl = try!(ssl.into_ssl());
|
|
||||||
let fd = stream.as_raw_socket() as c_int;
|
|
||||||
let ssl = try!(NonblockingSslStream::new_base(ssl, stream, fd));
|
|
||||||
let ret = ssl.ssl.accept();
|
|
||||||
if ret > 0 {
|
|
||||||
Ok(ssl)
|
|
||||||
} else {
|
|
||||||
// WantRead/WantWrite is okay here; we'll finish the handshake in
|
|
||||||
// subsequent send/recv calls.
|
|
||||||
match ssl.make_error(ret) {
|
|
||||||
NonblockingSslError::WantRead | NonblockingSslError::WantWrite => Ok(ssl),
|
|
||||||
NonblockingSslError::SslError(other) => Err(other),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: Read+Write> NonblockingSslStream<S> {
|
impl<S: Read+Write> NonblockingSslStream<S> {
|
||||||
|
/// Create a new nonblocking client ssl connection on wrapped `stream`.
|
||||||
|
///
|
||||||
|
/// Note that this method will most likely not actually complete the SSL
|
||||||
|
/// handshake because doing so requires several round trips; the handshake will
|
||||||
|
/// be completed in subsequent read/write calls managed by your event loop.
|
||||||
|
pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
|
||||||
|
SslStream::connect(ssl, stream).map(NonblockingSslStream)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new nonblocking server ssl connection on wrapped `stream`.
|
||||||
|
///
|
||||||
|
/// Note that this method will most likely not actually complete the SSL
|
||||||
|
/// handshake because doing so requires several round trips; the handshake will
|
||||||
|
/// be completed in subsequent read/write calls managed by your event loop.
|
||||||
|
pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<NonblockingSslStream<S>, SslError> {
|
||||||
|
SslStream::accept(ssl, stream).map(NonblockingSslStream)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_err(&self, err: Error) -> NonblockingSslError {
|
||||||
|
match err {
|
||||||
|
Error::ZeroReturn => SslError::SslSessionClosed.into(),
|
||||||
|
Error::WantRead(_) => NonblockingSslError::WantRead,
|
||||||
|
Error::WantWrite(_) => NonblockingSslError::WantWrite,
|
||||||
|
Error::WantX509Lookup => unreachable!(),
|
||||||
|
Error::Stream(e) => SslError::StreamError(e).into(),
|
||||||
|
Error::Ssl(e) => {
|
||||||
|
SslError::OpenSslErrors(e.iter()
|
||||||
|
.map(|e| OpensslError::from_error_code(e.error_code()))
|
||||||
|
.collect())
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Read bytes from the SSL stream into `buf`.
|
/// Read bytes from the SSL stream into `buf`.
|
||||||
///
|
///
|
||||||
/// Given the SSL state machine, this method may return either `WantWrite`
|
/// Given the SSL state machine, this method may return either `WantWrite`
|
||||||
|
|
@ -1442,11 +1367,10 @@ impl<S: Read+Write> NonblockingSslStream<S> {
|
||||||
/// On a return value of `Ok(count)`, count is the number of decrypted
|
/// On a return value of `Ok(count)`, count is the number of decrypted
|
||||||
/// plaintext bytes copied into the `buf` slice.
|
/// plaintext bytes copied into the `buf` slice.
|
||||||
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, NonblockingSslError> {
|
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, NonblockingSslError> {
|
||||||
let ret = self.ssl.read(buf);
|
match self.0.ssl_read(buf) {
|
||||||
if ret >= 0 {
|
Ok(n) => Ok(n),
|
||||||
Ok(ret as usize)
|
Err(Error::ZeroReturn) => Ok(0),
|
||||||
} else {
|
Err(e) => Err(self.convert_err(e))
|
||||||
Err(self.make_error(ret))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1466,11 +1390,6 @@ impl<S: Read+Write> NonblockingSslStream<S> {
|
||||||
/// Given a return value of `Ok(count)`, count is the number of plaintext bytes
|
/// Given a return value of `Ok(count)`, count is the number of plaintext bytes
|
||||||
/// from the `buf` slice that were encrypted and written onto the stream.
|
/// from the `buf` slice that were encrypted and written onto the stream.
|
||||||
pub fn write(&mut self, buf: &[u8]) -> Result<usize, NonblockingSslError> {
|
pub fn write(&mut self, buf: &[u8]) -> Result<usize, NonblockingSslError> {
|
||||||
let ret = self.ssl.write(buf);
|
self.0.ssl_write(buf).map_err(|e| self.convert_err(e))
|
||||||
if ret > 0 {
|
|
||||||
Ok(ret as usize)
|
|
||||||
} else {
|
|
||||||
Err(self.make_error(ret))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue