diff --git a/src/ssl/ffi.rs b/src/ssl/ffi.rs index 07591b78..e12bb3c2 100644 --- a/src/ssl/ffi.rs +++ b/src/ssl/ffi.rs @@ -1,6 +1,6 @@ #[doc(hidden)]; -use std::libc::{c_int, c_long, c_void}; +use std::libc::{c_int, c_void}; // openssl/ssl.h pub type SSL_CTX = c_void; @@ -35,6 +35,8 @@ externfn!(fn SSL_set_bio(ssl: *SSL, rbio: *BIO, wbio: *BIO)) externfn!(fn SSL_set_connect_state(ssl: *SSL)) externfn!(fn SSL_connect(ssl: *SSL) -> c_int) externfn!(fn SSL_get_error(ssl: *SSL, ret: c_int) -> c_int) +externfn!(fn SSL_read(ssl: *SSL, buf: *c_void, num: c_int) -> c_int) +externfn!(fn SSL_write(ssl: *SSL, buf: *c_void, num: c_int) -> c_int) externfn!(fn BIO_s_mem() -> *BIO_METHOD) externfn!(fn BIO_new(type_: *BIO_METHOD) -> *BIO) diff --git a/src/ssl/lib.rs b/src/ssl/lib.rs index 401a7b3d..5cc59330 100644 --- a/src/ssl/lib.rs +++ b/src/ssl/lib.rs @@ -1,4 +1,4 @@ -use std::rt::io::{Stream, Decorator}; +use std::rt::io::{Reader, Writer, Stream, Decorator}; use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release}; use std::task; use std::ptr; @@ -70,6 +70,7 @@ impl Drop for Ssl { } } +#[deriving(Eq, TotalEq, ToStr)] enum SslError { ErrorNone, ErrorSsl, @@ -114,6 +115,20 @@ impl Ssl { _ => unreachable!() } } + + fn read(&self, buf: &[u8]) -> int { + unsafe { + ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void, + buf.len() as c_int) as int + } + } + + fn write(&self, buf: &[u8]) -> int { + unsafe { + ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void, + buf.len() as c_int) as int + } + } } struct MemBio { @@ -150,9 +165,10 @@ impl MemBio { let ret = ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int); if ret < 0 { - fail2!("read returned {}", ret); + 0 + } else { + ret as uint } - ret as uint } } } @@ -179,47 +195,86 @@ impl SslStream { let mut stream = SslStream { ctx: ctx, ssl: ssl, + // Max record size for SSLv3/TLSv1 is 16k buf: vec::from_elem(16 * 1024, 0u8), rbio: rbio, wbio: wbio, stream: stream }; - stream.connect(); + do stream.in_retry_wrapper |ssl| { + ssl.ssl.connect() + }; stream } - fn connect(&mut self) { - info!("in connect"); + fn in_retry_wrapper(&mut self, blk: &fn(&mut SslStream) -> int) + -> Result { loop { - let ret = self.ssl.connect(); - info2!("connect returned {}", ret); - if ret == 1 { - return; + let ret = blk(self); + if ret > 0 { + return Ok(ret); } match self.ssl.get_error(ret) { ErrorWantRead => { - info2!("want read"); self.flush(); match self.stream.read(self.buf) { Some(len) => self.rbio.write(self.buf.slice_to(len)), - None => unreachable!() + None => unreachable!() // FIXME } } - ErrorWantWrite => { - info2!("want write"); - self.flush(); - } - _ => unreachable!() + ErrorWantWrite => self.flush(), + err => return Err(err) } } } + fn write_through(&mut self) { + loop { + let len = self.wbio.read(self.buf); + if len == 0 { + return; + } + self.stream.write(self.buf.slice_to(len)); + } + } +} + +impl Reader for SslStream { + fn read(&mut self, buf: &mut [u8]) -> Option { + let ret = do self.in_retry_wrapper |ssl| { + ssl.ssl.read(buf) + }; + + match ret { + Ok(num) => Some(num as uint), + Err(_) => None + } + } + + fn eof(&mut self) -> bool { + self.stream.eof() + } +} + +impl Writer for SslStream { + fn write(&mut self, buf: &[u8]) { + let ret = do self.in_retry_wrapper |ssl| { + ssl.ssl.write(buf) + }; + + match ret { + Ok(_) => (), + Err(err) => fail2!("Write error: {}", err.to_str()) + } + + self.write_through(); + } + fn flush(&mut self) { - let len = self.wbio.read(self.buf); - self.stream.write(self.buf.slice_to(len)); + self.write_through(); self.stream.flush(); } } diff --git a/src/ssl/test.rs b/src/ssl/test.rs index a86772a7..7a6dc997 100644 --- a/src/ssl/test.rs +++ b/src/ssl/test.rs @@ -1,6 +1,8 @@ extern mod ssl; +use std::rt::io::{Writer}; use std::rt::io::net::tcp::TcpStream; +use std::vec; use ssl::{Sslv23, SslCtx, SslStream}; @@ -14,3 +16,11 @@ fn test_new_sslstream() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()); SslStream::new(SslCtx::new(Sslv23), stream); } + +#[test] +fn test_write() { + let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()); + let mut stream = SslStream::new(SslCtx::new(Sslv23), stream); + stream.write("hello".as_bytes()); + stream.flush(); +}