From 1373a76ce12d6a856b6caae7457ceb3eb5ad4122 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 28 Jun 2015 00:06:14 -0700 Subject: [PATCH] Implement direct IO support --- openssl/src/ssl/mod.rs | 181 ++++++++++++++++++++++++++++++++++++--- openssl/src/ssl/tests.rs | 19 +++- 2 files changed, 187 insertions(+), 13 deletions(-) diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 18acf7f8..0e1e5b30 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -603,11 +603,6 @@ impl Ssl { return Err(SslError::get()); } let ssl = Ssl { ssl: ssl }; - - let rbio = try!(MemBio::new()); - let wbio = try!(MemBio::new()); - - unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) } Ok(ssl) } @@ -769,6 +764,12 @@ impl IndirectStream { impl IndirectStream { fn new_base(ssl: T, stream: S) -> Result, SslError> { let ssl = try!(ssl.into_ssl()); + + let rbio = try!(MemBio::new()); + let wbio = try!(MemBio::new()); + + unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) } + Ok(IndirectStream { stream: stream, ssl: Arc::new(ssl), @@ -852,27 +853,139 @@ impl Write for IndirectStream { } } +#[derive(Clone)] +struct DirectStream { + stream: S, + ssl: Arc, +} + +impl DirectStream { + fn try_clone(&self) -> io::Result> { + Ok(DirectStream { + stream: try!(self.stream.try_clone()), + ssl: self.ssl.clone(), + }) + } +} + +impl DirectStream { + fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result, SslError> { + unsafe { + let bio = ffi::BIO_new_socket(sock, 0); + if bio == ptr::null_mut() { + return Err(SslError::get()); + } + ffi::SSL_set_bio(ssl.ssl, bio, bio); + } + + Ok(DirectStream { + stream: stream, + ssl: Arc::new(ssl), + }) + } + + fn new_client(ssl: Ssl, stream: S, sock: c_int) -> Result, SslError> { + let ssl = try!(DirectStream::new_base(ssl, stream, sock)); + let ret = ssl.ssl.connect(); + if ret > 0 { + Ok(ssl) + } else { + Err(ssl.make_error(ret)) + } + } + + fn new_server(ssl: Ssl, stream: S, sock: c_int) -> Result, SslError> { + let ssl = try!(DirectStream::new_base(ssl, stream, sock)); + let ret = ssl.ssl.accept(); + if ret > 0 { + Ok(ssl) + } else { + Err(ssl.make_error(ret)) + } + } + + fn make_error(&self, ret: c_int) -> SslError { + match self.ssl.get_error(ret) { + LibSslError::ErrorSsl => SslError::get(), + LibSslError::ErrorSyscall => { + let err = SslError::get(); + let count = match err { + SslError::OpenSslErrors(ref v) => v.len(), + _ => unreachable!(), + }; + if count == 0 { + if ret == 0 { + SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted, + "unexpected EOF observed")) + } else { + SslError::StreamError(io::Error::last_os_error()) + } + } else { + err + } + } + err => panic!("unexpected error {:?} with ret {}", err, ret), + } + } +} + +impl Read for DirectStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let ret = self.ssl.read(buf); + if ret >= 0 { + return Ok(ret as usize); + } + + match self.make_error(ret) { + SslError::StreamError(e) => Err(e), + e => Err(io::Error::new(io::ErrorKind::Other, e)), + } + } +} + +impl Write for DirectStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + let ret = self.ssl.write(buf); + if ret > 0 { + return Ok(ret as usize); + } + + match self.make_error(ret) { + SslError::StreamError(e) => Err(e), + e => Err(io::Error::new(io::ErrorKind::Other, e)), + } + } + + fn flush(&mut self) -> io::Result<()> { + self.stream.flush() + } +} + #[derive(Clone)] enum StreamKind { Indirect(IndirectStream), + Direct(DirectStream), } impl StreamKind { fn stream(&self) -> &S { match *self { - StreamKind::Indirect(ref s) => &s.stream + StreamKind::Indirect(ref s) => &s.stream, + StreamKind::Direct(ref s) => &s.stream, } } fn mut_stream(&mut self) -> &mut S { match *self { - StreamKind::Indirect(ref mut s) => &mut s.stream + StreamKind::Indirect(ref mut s) => &mut s.stream, + StreamKind::Direct(ref mut s) => &mut s.stream, } } fn ssl(&self) -> &Ssl { match *self { - StreamKind::Indirect(ref s) => &s.ssl + StreamKind::Indirect(ref s) => &s.ssl, + StreamKind::Direct(ref s) => &s.ssl, } } } @@ -887,7 +1000,8 @@ impl SslStream { /// Create a new independently owned handle to the underlying socket. pub fn try_clone(&self) -> io::Result> { let kind = match self.kind { - StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())) + StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())), + StreamKind::Direct(ref s) => StreamKind::Direct(try!(s.try_clone())) }; Ok(SslStream { kind: kind @@ -901,6 +1015,46 @@ impl fmt::Debug for SslStream where S: fmt::Debug { } } +#[cfg(unix)] +impl SslStream { + pub fn new_client_direct(ssl: T, stream: S) -> Result, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_fd() as c_int; + let stream = try!(DirectStream::new_client(ssl, stream, fd)); + Ok(SslStream { + kind: StreamKind::Direct(stream) + }) + } + + pub fn new_server_direct(ssl: T, stream: S) -> Result, SslError> { + let ssl = try!(ssl.into_ssl()); + let fd = stream.as_raw_fd() as c_int; + let stream = try!(DirectStream::new_server(ssl, stream, fd)); + Ok(SslStream { + kind: StreamKind::Direct(stream) + }) + } +} + +#[cfg(windows)] +impl SslStream { + pub fn new_client_direct(ssl: T, stream: S) -> Result, SslError> { + let fd = stream.as_raw_socket() as c_int; + let stream = try!(DirectStream::new_client(ssl, stream, fd)); + Ok(SslStream { + kind: StreamKind::Direct(stream) + }) + } + + pub fn new_server_direct(ssl: T, stream: S) -> Result, SslError> { + let fd = stream.as_raw_socket() as c_int; + let stream = try!(DirectStream::new_server(ssl, stream, fd)); + Ok(SslStream { + kind: StreamKind::Direct(stream) + }) + } +} + impl SslStream { pub fn new_client(ssl: T, stream: S) -> Result, SslError> { let stream = try!(IndirectStream::new_client(ssl, stream)); @@ -994,7 +1148,8 @@ impl SslStream { impl Read for SslStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self.kind { - StreamKind::Indirect(ref mut s) => s.read(buf) + StreamKind::Indirect(ref mut s) => s.read(buf), + StreamKind::Direct(ref mut s) => s.read(buf), } } } @@ -1002,13 +1157,15 @@ impl Read for SslStream { impl Write for SslStream { fn write(&mut self, buf: &[u8]) -> io::Result { match self.kind { - StreamKind::Indirect(ref mut s) => s.write(buf) + StreamKind::Indirect(ref mut s) => s.write(buf), + StreamKind::Direct(ref mut s) => s.write(buf), } } fn flush(&mut self) -> io::Result<()> { match self.kind { - StreamKind::Indirect(ref mut s) => s.flush() + StreamKind::Indirect(ref mut s) => s.flush(), + StreamKind::Direct(ref mut s) => s.flush(), } } } diff --git a/openssl/src/ssl/tests.rs b/openssl/src/ssl/tests.rs index a0e4a9d6..2ba940ab 100644 --- a/openssl/src/ssl/tests.rs +++ b/openssl/src/ssl/tests.rs @@ -317,8 +317,17 @@ fn test_write() { stream.flush().unwrap(); } +#[test] +fn test_write_direct() { + let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); + let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); + stream.write_all("hello".as_bytes()).unwrap(); + stream.flush().unwrap(); + stream.write_all(" there".as_bytes()).unwrap(); + stream.flush().unwrap(); +} + run_test!(get_peer_certificate, |method, stream| { - //let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); let stream = SslStream::new_client(&SslContext::new(method).unwrap(), stream).unwrap(); let cert = stream.get_peer_certificate().unwrap(); let fingerprint = cert.fingerprint(SHA256).unwrap(); @@ -349,6 +358,14 @@ fn test_read() { io::copy(&mut stream, &mut io::sink()).ok().expect("read error"); } +#[test] +fn test_read_direct() { + let tcp = TcpStream::connect("127.0.0.1:15418").unwrap(); + let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), tcp).unwrap(); + stream.write_all("GET /\r\n\r\n".as_bytes()).unwrap(); + stream.flush().unwrap(); + io::copy(&mut stream, &mut io::sink()).ok().expect("read error"); +} #[test] fn test_pending() {