Merge pull request #432 from alexcrichton/mid-handshake
Add MidHandshakeSslStream
This commit is contained in:
commit
2574bff52d
|
|
@ -741,6 +741,7 @@ extern "C" {
|
|||
pub fn SSL_get_wbio(ssl: *mut SSL) -> *mut BIO;
|
||||
pub fn SSL_accept(ssl: *mut SSL) -> c_int;
|
||||
pub fn SSL_connect(ssl: *mut SSL) -> c_int;
|
||||
pub fn SSL_do_handshake(ssl: *mut SSL) -> c_int;
|
||||
pub fn SSL_ctrl(ssl: *mut SSL, cmd: c_int, larg: c_long,
|
||||
parg: *mut c_void) -> c_long;
|
||||
pub fn SSL_get_error(ssl: *mut SSL, ret: c_int) -> c_int;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use std::ffi::{CStr, CString};
|
|||
use std::fmt;
|
||||
use std::io;
|
||||
use std::io::prelude::*;
|
||||
use std::error as stderror;
|
||||
use std::mem;
|
||||
use std::str;
|
||||
use std::path::Path;
|
||||
|
|
@ -832,6 +833,10 @@ impl Ssl {
|
|||
unsafe { ffi::SSL_accept(self.ssl) }
|
||||
}
|
||||
|
||||
fn handshake(&self) -> c_int {
|
||||
unsafe { ffi::SSL_do_handshake(self.ssl) }
|
||||
}
|
||||
|
||||
fn read(&self, buf: &mut [u8]) -> c_int {
|
||||
let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int;
|
||||
unsafe { ffi::SSL_read(self.ssl, buf.as_ptr() as *mut c_void, len) }
|
||||
|
|
@ -1081,31 +1086,49 @@ impl<S: Read + Write> SslStream<S> {
|
|||
}
|
||||
|
||||
/// Creates an SSL/TLS client operating over the provided stream.
|
||||
pub fn connect<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> {
|
||||
let ssl = try!(ssl.into_ssl());
|
||||
pub fn connect<T: IntoSsl>(ssl: T, stream: S)
|
||||
-> Result<Self, HandshakeError<S>>{
|
||||
let ssl = try!(ssl.into_ssl().map_err(|e| {
|
||||
HandshakeError::Failure(Error::Ssl(e))
|
||||
}));
|
||||
let mut stream = Self::new_base(ssl, stream);
|
||||
let ret = stream.ssl.connect();
|
||||
if ret > 0 {
|
||||
Ok(stream)
|
||||
} else {
|
||||
match stream.make_error(ret) {
|
||||
Error::WantRead(..) | Error::WantWrite(..) => Ok(stream),
|
||||
err => Err(err)
|
||||
e @ Error::WantWrite(_) |
|
||||
e @ Error::WantRead(_) => {
|
||||
Err(HandshakeError::Interrupted(MidHandshakeSslStream {
|
||||
stream: stream,
|
||||
error: e,
|
||||
}))
|
||||
}
|
||||
err => Err(HandshakeError::Failure(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an SSL/TLS server operating over the provided stream.
|
||||
pub fn accept<T: IntoSsl>(ssl: T, stream: S) -> Result<Self, Error> {
|
||||
let ssl = try!(ssl.into_ssl());
|
||||
pub fn accept<T: IntoSsl>(ssl: T, stream: S)
|
||||
-> Result<Self, HandshakeError<S>> {
|
||||
let ssl = try!(ssl.into_ssl().map_err(|e| {
|
||||
HandshakeError::Failure(Error::Ssl(e))
|
||||
}));
|
||||
let mut stream = Self::new_base(ssl, stream);
|
||||
let ret = stream.ssl.accept();
|
||||
if ret > 0 {
|
||||
Ok(stream)
|
||||
} else {
|
||||
match stream.make_error(ret) {
|
||||
Error::WantRead(..) | Error::WantWrite(..) => Ok(stream),
|
||||
err => Err(err)
|
||||
e @ Error::WantWrite(_) |
|
||||
e @ Error::WantRead(_) => {
|
||||
Err(HandshakeError::Interrupted(MidHandshakeSslStream {
|
||||
stream: stream,
|
||||
error: e,
|
||||
}))
|
||||
}
|
||||
err => Err(HandshakeError::Failure(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1137,6 +1160,87 @@ impl<S: Read + Write> SslStream<S> {
|
|||
}
|
||||
}
|
||||
|
||||
/// An error or intermediate state after a TLS handshake attempt.
|
||||
#[derive(Debug)]
|
||||
pub enum HandshakeError<S> {
|
||||
/// The handshake failed.
|
||||
Failure(Error),
|
||||
/// The handshake was interrupted midway through.
|
||||
Interrupted(MidHandshakeSslStream<S>),
|
||||
}
|
||||
|
||||
impl<S: Any + fmt::Debug> stderror::Error for HandshakeError<S> {
|
||||
fn description(&self) -> &str {
|
||||
match *self {
|
||||
HandshakeError::Failure(ref e) => e.description(),
|
||||
HandshakeError::Interrupted(ref e) => e.error.description(),
|
||||
}
|
||||
}
|
||||
|
||||
fn cause(&self) -> Option<&stderror::Error> {
|
||||
match *self {
|
||||
HandshakeError::Failure(ref e) => Some(e),
|
||||
HandshakeError::Interrupted(ref e) => Some(&e.error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Any + fmt::Debug> fmt::Display for HandshakeError<S> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
try!(f.write_str(stderror::Error::description(self)));
|
||||
if let Some(e) = stderror::Error::cause(self) {
|
||||
try!(write!(f, ": {}", e));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// An SSL stream midway through the handshake process.
|
||||
#[derive(Debug)]
|
||||
pub struct MidHandshakeSslStream<S> {
|
||||
stream: SslStream<S>,
|
||||
error: Error,
|
||||
}
|
||||
|
||||
impl<S> MidHandshakeSslStream<S> {
|
||||
/// Returns a shared reference to the inner stream.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
self.stream.get_ref()
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the inner stream.
|
||||
pub fn get_mut(&mut self) -> &mut S {
|
||||
self.stream.get_mut()
|
||||
}
|
||||
|
||||
/// Returns a shared reference to the `SslContext` of the stream.
|
||||
pub fn ssl(&self) -> &Ssl {
|
||||
self.stream.ssl()
|
||||
}
|
||||
|
||||
/// Returns the underlying error which interrupted this handshake.
|
||||
pub fn error(&self) -> &Error {
|
||||
&self.error
|
||||
}
|
||||
|
||||
/// Restarts the handshake process.
|
||||
pub fn handshake(mut self) -> Result<SslStream<S>, HandshakeError<S>> {
|
||||
let ret = self.stream.ssl.handshake();
|
||||
if ret > 0 {
|
||||
Ok(self.stream)
|
||||
} else {
|
||||
match self.stream.make_error(ret) {
|
||||
e @ Error::WantWrite(_) |
|
||||
e @ Error::WantRead(_) => {
|
||||
self.error = e;
|
||||
Err(HandshakeError::Interrupted(self))
|
||||
}
|
||||
err => Err(HandshakeError::Failure(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> SslStream<S> {
|
||||
fn make_error(&mut self, ret: c_int) -> Error {
|
||||
self.check_panic();
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ use crypto::hash::Type::SHA256;
|
|||
use ssl;
|
||||
use ssl::SSL_VERIFY_PEER;
|
||||
use ssl::SslMethod::Sslv23;
|
||||
use ssl::SslMethod;
|
||||
use ssl::{SslMethod, HandshakeError};
|
||||
use ssl::error::Error;
|
||||
use ssl::{SslContext, SslStream};
|
||||
use x509::X509StoreContext;
|
||||
|
|
@ -133,6 +133,7 @@ impl Drop for Server {
|
|||
}
|
||||
|
||||
#[cfg(feature = "dtlsv1")]
|
||||
#[derive(Debug)]
|
||||
struct UdpConnected(UdpSocket);
|
||||
|
||||
#[cfg(feature = "dtlsv1")]
|
||||
|
|
@ -846,10 +847,10 @@ fn test_sslv2_connect_failure() {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
fn wait_io(stream: &SslStream<TcpStream>, read: bool, timeout_ms: u32) -> bool {
|
||||
fn wait_io(stream: &TcpStream, read: bool, timeout_ms: u32) -> bool {
|
||||
unsafe {
|
||||
let mut set: select::fd_set = mem::zeroed();
|
||||
select::fd_set(&mut set, stream.get_ref());
|
||||
select::fd_set(&mut set, stream);
|
||||
|
||||
let write = if read {
|
||||
0 as *mut _
|
||||
|
|
@ -861,7 +862,19 @@ fn wait_io(stream: &SslStream<TcpStream>, read: bool, timeout_ms: u32) -> bool {
|
|||
} else {
|
||||
&mut set as *mut _
|
||||
};
|
||||
select::select(stream.get_ref(), read, write, 0 as *mut _, timeout_ms).unwrap()
|
||||
select::select(stream, read, write, 0 as *mut _, timeout_ms).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn handshake(res: Result<SslStream<TcpStream>, HandshakeError<TcpStream>>)
|
||||
-> SslStream<TcpStream> {
|
||||
match res {
|
||||
Ok(s) => s,
|
||||
Err(HandshakeError::Interrupted(s)) => {
|
||||
wait_io(s.get_ref(), true, 1_000);
|
||||
handshake(s.handshake())
|
||||
}
|
||||
Err(err) => panic!("error on handshake {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -870,7 +883,7 @@ fn test_write_nonblocking() {
|
|||
let (_s, stream) = Server::new();
|
||||
stream.set_nonblocking(true).unwrap();
|
||||
let cx = SslContext::new(Sslv23).unwrap();
|
||||
let mut stream = SslStream::connect(&cx, stream).unwrap();
|
||||
let mut stream = handshake(SslStream::connect(&cx, stream));
|
||||
|
||||
let mut iterations = 0;
|
||||
loop {
|
||||
|
|
@ -886,10 +899,10 @@ fn test_write_nonblocking() {
|
|||
break;
|
||||
}
|
||||
Err(Error::WantRead(_)) => {
|
||||
assert!(wait_io(&stream, true, 1000));
|
||||
assert!(wait_io(stream.get_ref(), true, 1000));
|
||||
}
|
||||
Err(Error::WantWrite(_)) => {
|
||||
assert!(wait_io(&stream, false, 1000));
|
||||
assert!(wait_io(stream.get_ref(), false, 1000));
|
||||
}
|
||||
Err(other) => {
|
||||
panic!("Unexpected SSL Error: {:?}", other);
|
||||
|
|
@ -907,7 +920,7 @@ fn test_read_nonblocking() {
|
|||
let (_s, stream) = Server::new();
|
||||
stream.set_nonblocking(true).unwrap();
|
||||
let cx = SslContext::new(Sslv23).unwrap();
|
||||
let mut stream = SslStream::connect(&cx, stream).unwrap();
|
||||
let mut stream = handshake(SslStream::connect(&cx, stream));
|
||||
|
||||
let mut iterations = 0;
|
||||
loop {
|
||||
|
|
@ -924,10 +937,10 @@ fn test_read_nonblocking() {
|
|||
break;
|
||||
}
|
||||
Err(Error::WantRead(..)) => {
|
||||
assert!(wait_io(&stream, true, 1000));
|
||||
assert!(wait_io(stream.get_ref(), true, 1000));
|
||||
}
|
||||
Err(Error::WantWrite(..)) => {
|
||||
assert!(wait_io(&stream, false, 1000));
|
||||
assert!(wait_io(stream.get_ref(), false, 1000));
|
||||
}
|
||||
Err(other) => {
|
||||
panic!("Unexpected SSL Error: {:?}", other);
|
||||
|
|
@ -944,7 +957,7 @@ fn test_read_nonblocking() {
|
|||
n
|
||||
}
|
||||
Err(Error::WantRead(..)) => {
|
||||
assert!(wait_io(&stream, true, 3000));
|
||||
assert!(wait_io(stream.get_ref(), true, 3000));
|
||||
// Second read should return application data.
|
||||
stream.read(&mut input_buffer).unwrap()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue