Add MidHandshakeSslStream
Allows recognizing when a stream is still in handshake mode and can gracefully transition when ready. The blocking usage of the API should still be the same, just helps nonblocking implementations!
This commit is contained in:
parent
f0ffa246b8
commit
3539be3366
|
|
@ -741,6 +741,7 @@ extern "C" {
|
|||
pub fn SSL_get_wbio(ssl: *mut SSL) -> *mut BIO;
|
||||
pub fn SSL_accept(ssl: *mut SSL) -> c_int;
|
||||
pub fn SSL_connect(ssl: *mut SSL) -> c_int;
|
||||
pub fn SSL_do_handshake(ssl: *mut SSL) -> c_int;
|
||||
pub fn SSL_ctrl(ssl: *mut SSL, cmd: c_int, larg: c_long,
|
||||
parg: *mut c_void) -> c_long;
|
||||
pub fn SSL_get_error(ssl: *mut SSL, ret: c_int) -> c_int;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use std::ffi::{CStr, CString};
|
|||
use std::fmt;
|
||||
use std::io;
|
||||
use std::io::prelude::*;
|
||||
use std::error as stderror;
|
||||
use std::mem;
|
||||
use std::str;
|
||||
use std::path::Path;
|
||||
|
|
@ -832,6 +833,10 @@ impl Ssl {
|
|||
unsafe { ffi::SSL_accept(self.ssl) }
|
||||
}
|
||||
|
||||
fn handshake(&self) -> c_int {
|
||||
unsafe { ffi::SSL_do_handshake(self.ssl) }
|
||||
}
|
||||
|
||||
fn read(&self, buf: &mut [u8]) -> c_int {
|
||||
let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int;
|
||||
unsafe { ffi::SSL_read(self.ssl, buf.as_ptr() as *mut c_void, len) }
|
||||
|
|
@ -1081,31 +1086,49 @@ impl<S: Read + Write> SslStream<S> {
|
|||
}
|
||||
|
||||
/// Creates an SSL/TLS client operating over the provided stream.
|
||||
pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> {
|
||||
let ssl = try!(ssl.into_ssl());
|
||||
pub fn connect<T: IntoSsl>(ssl: T, stream: S)
|
||||
-> Result<Self, HandshakeError<S>>{
|
||||
let ssl = try!(ssl.into_ssl().map_err(|e| {
|
||||
HandshakeError::Failure(Error::Ssl(e))
|
||||
}));
|
||||
let mut stream = Self::new_base(ssl, stream);
|
||||
let ret = stream.ssl.connect();
|
||||
if ret > 0 {
|
||||
Ok(stream)
|
||||
} else {
|
||||
match stream.make_error(ret) {
|
||||
Error::WantRead(..) | Error::WantWrite(..) => Ok(stream),
|
||||
err => Err(err)
|
||||
e @ Error::WantWrite(_) |
|
||||
e @ Error::WantRead(_) => {
|
||||
Err(HandshakeError::Interrupted(MidHandshakeSslStream {
|
||||
stream: stream,
|
||||
error: e,
|
||||
}))
|
||||
}
|
||||
err => Err(HandshakeError::Failure(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an SSL/TLS server operating over the provided stream.
|
||||
pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> {
|
||||
let ssl = try!(ssl.into_ssl());
|
||||
pub fn accept<T: IntoSsl>(ssl: T, stream: S)
|
||||
-> Result<Self, HandshakeError<S>> {
|
||||
let ssl = try!(ssl.into_ssl().map_err(|e| {
|
||||
HandshakeError::Failure(Error::Ssl(e))
|
||||
}));
|
||||
let mut stream = Self::new_base(ssl, stream);
|
||||
let ret = stream.ssl.accept();
|
||||
if ret > 0 {
|
||||
Ok(stream)
|
||||
} else {
|
||||
match stream.make_error(ret) {
|
||||
Error::WantRead(..) | Error::WantWrite(..) => Ok(stream),
|
||||
err => Err(err)
|
||||
e @ Error::WantWrite(_) |
|
||||
e @ Error::WantRead(_) => {
|
||||
Err(HandshakeError::Interrupted(MidHandshakeSslStream {
|
||||
stream: stream,
|
||||
error: e,
|
||||
}))
|
||||
}
|
||||
err => Err(HandshakeError::Failure(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1137,6 +1160,87 @@ impl<S: Read + Write> SslStream<S> {
|
|||
}
|
||||
}
|
||||
|
||||
/// An error or intermediate state after a TLS handshake attempt.
|
||||
#[derive(Debug)]
|
||||
pub enum HandshakeError<S> {
|
||||
/// The handshake failed.
|
||||
Failure(Error),
|
||||
/// The handshake was interrupted midway through.
|
||||
Interrupted(MidHandshakeSslStream<S>),
|
||||
}
|
||||
|
||||
impl<S: Any + fmt::Debug> stderror::Error for HandshakeError<S> {
|
||||
fn description(&self) -> &str {
|
||||
match *self {
|
||||
HandshakeError::Failure(ref e) => e.description(),
|
||||
HandshakeError::Interrupted(ref e) => e.error.description(),
|
||||
}
|
||||
}
|
||||
|
||||
fn cause(&self) -> Option<&stderror::Error> {
|
||||
match *self {
|
||||
HandshakeError::Failure(ref e) => Some(e),
|
||||
HandshakeError::Interrupted(ref e) => Some(&e.error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Any + fmt::Debug> fmt::Display for HandshakeError<S> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
try!(f.write_str(stderror::Error::description(self)));
|
||||
if let Some(e) = stderror::Error::cause(self) {
|
||||
try!(write!(f, ": {}", e));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// An SSL stream midway through the handshake process.
|
||||
#[derive(Debug)]
|
||||
pub struct MidHandshakeSslStream<S> {
|
||||
stream: SslStream<S>,
|
||||
error: Error,
|
||||
}
|
||||
|
||||
impl<S> MidHandshakeSslStream<S> {
|
||||
/// Returns a shared reference to the inner stream.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
self.stream.get_ref()
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the inner stream.
|
||||
pub fn get_mut(&mut self) -> &mut S {
|
||||
self.stream.get_mut()
|
||||
}
|
||||
|
||||
/// Returns a shared reference to the `SslContext` of the stream.
|
||||
pub fn ssl(&self) -> &Ssl {
|
||||
self.stream.ssl()
|
||||
}
|
||||
|
||||
/// Returns the underlying error which interrupted this handshake.
|
||||
pub fn error(&self) -> &Error {
|
||||
&self.error
|
||||
}
|
||||
|
||||
/// Restarts the handshake process.
|
||||
pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> {
|
||||
let ret = self.stream.ssl.handshake();
|
||||
if ret > 0 {
|
||||
Ok(self.stream)
|
||||
} else {
|
||||
match self.stream.make_error(ret) {
|
||||
e @ Error::WantWrite(_) |
|
||||
e @ Error::WantRead(_) => {
|
||||
self.error = e;
|
||||
Err(HandshakeError::Interrupted(self))
|
||||
}
|
||||
err => Err(HandshakeError::Failure(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> SslStream<S> {
|
||||
fn make_error(&mut self, ret: c_int) -> Error {
|
||||
self.check_panic();
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ use crypto::hash::Type::SHA256;
|
|||
use ssl;
|
||||
use ssl::SSL_VERIFY_PEER;
|
||||
use ssl::SslMethod::Sslv23;
|
||||
use ssl::SslMethod;
|
||||
use ssl::{SslMethod, HandshakeError};
|
||||
use ssl::error::Error;
|
||||
use ssl::{SslContext, SslStream};
|
||||
use x509::X509StoreContext;
|
||||
|
|
@ -133,6 +133,7 @@ impl Drop for Server {
|
|||
}
|
||||
|
||||
#[cfg(feature = "dtlsv1")]
|
||||
#[derive(Debug)]
|
||||
struct UdpConnected(UdpSocket);
|
||||
|
||||
#[cfg(feature = "dtlsv1")]
|
||||
|
|
@ -846,10 +847,10 @@ fn test_sslv2_connect_failure() {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
fn wait_io(stream: &SslStream<TcpStream>, read: bool, timeout_ms: u32) -> bool {
|
||||
fn wait_io(stream: &TcpStream, read: bool, timeout_ms: u32) -> bool {
|
||||
unsafe {
|
||||
let mut set: select::fd_set = mem::zeroed();
|
||||
select::fd_set(&mut set, stream.get_ref());
|
||||
select::fd_set(&mut set, stream);
|
||||
|
||||
let write = if read {
|
||||
0 as *mut _
|
||||
|
|
@ -861,7 +862,19 @@ fn wait_io(stream: &SslStream<TcpStream>, read: bool, timeout_ms: u32) -> bool {
|
|||
} else {
|
||||
&mut set as *mut _
|
||||
};
|
||||
select::select(stream.get_ref(), read, write, 0 as *mut _, timeout_ms).unwrap()
|
||||
select::select(stream, read, write, 0 as *mut _, timeout_ms).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn handshake(res: Result<SslStream<TcpStream>, HandshakeError<TcpStream>>)
|
||||
-> SslStream<TcpStream> {
|
||||
match res {
|
||||
Ok(s) => s,
|
||||
Err(HandshakeError::Interrupted(s)) => {
|
||||
wait_io(s.get_ref(), true, 1_000);
|
||||
handshake(s.handshake())
|
||||
}
|
||||
Err(err) => panic!("error on handshake {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -870,7 +883,7 @@ fn test_write_nonblocking() {
|
|||
let (_s, stream) = Server::new();
|
||||
stream.set_nonblocking(true).unwrap();
|
||||
let cx = SslContext::new(Sslv23).unwrap();
|
||||
let mut stream = SslStream::connect(&cx, stream).unwrap();
|
||||
let mut stream = handshake(SslStream::connect(&cx, stream));
|
||||
|
||||
let mut iterations = 0;
|
||||
loop {
|
||||
|
|
@ -886,10 +899,10 @@ fn test_write_nonblocking() {
|
|||
break;
|
||||
}
|
||||
Err(Error::WantRead(_)) => {
|
||||
assert!(wait_io(&stream, true, 1000));
|
||||
assert!(wait_io(stream.get_ref(), true, 1000));
|
||||
}
|
||||
Err(Error::WantWrite(_)) => {
|
||||
assert!(wait_io(&stream, false, 1000));
|
||||
assert!(wait_io(stream.get_ref(), false, 1000));
|
||||
}
|
||||
Err(other) => {
|
||||
panic!("Unexpected SSL Error: {:?}", other);
|
||||
|
|
@ -907,7 +920,7 @@ fn test_read_nonblocking() {
|
|||
let (_s, stream) = Server::new();
|
||||
stream.set_nonblocking(true).unwrap();
|
||||
let cx = SslContext::new(Sslv23).unwrap();
|
||||
let mut stream = SslStream::connect(&cx, stream).unwrap();
|
||||
let mut stream = handshake(SslStream::connect(&cx, stream));
|
||||
|
||||
let mut iterations = 0;
|
||||
loop {
|
||||
|
|
@ -924,10 +937,10 @@ fn test_read_nonblocking() {
|
|||
break;
|
||||
}
|
||||
Err(Error::WantRead(..)) => {
|
||||
assert!(wait_io(&stream, true, 1000));
|
||||
assert!(wait_io(stream.get_ref(), true, 1000));
|
||||
}
|
||||
Err(Error::WantWrite(..)) => {
|
||||
assert!(wait_io(&stream, false, 1000));
|
||||
assert!(wait_io(stream.get_ref(), false, 1000));
|
||||
}
|
||||
Err(other) => {
|
||||
panic!("Unexpected SSL Error: {:?}", other);
|
||||
|
|
@ -944,7 +957,7 @@ fn test_read_nonblocking() {
|
|||
n
|
||||
}
|
||||
Err(Error::WantRead(..)) => {
|
||||
assert!(wait_io(&stream, true, 3000));
|
||||
assert!(wait_io(stream.get_ref(), true, 3000));
|
||||
// Second read should return application data.
|
||||
stream.read(&mut input_buffer).unwrap()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue