diff --git a/error.rs b/error.rs new file mode 100644 index 00000000..b5fe0b6b --- /dev/null +++ b/error.rs @@ -0,0 +1,18 @@ +use std::libc::c_ulong; + +use super::ffi; + +pub enum SslError { + StreamEof, + SslSessionClosed, + UnknownError(c_ulong) +} + +impl SslError { + pub fn get() -> Option { + match unsafe { ffi::ERR_get_error() } { + 0 => None, + err => Some(UnknownError(err)) + } + } +} diff --git a/ffi.rs b/ffi.rs index 57adfaf4..7fec9ad0 100644 --- a/ffi.rs +++ b/ffi.rs @@ -45,6 +45,8 @@ externfn!(fn SSL_CTX_load_verify_locations(ctx: *SSL_CTX, CAfile: *c_char, externfn!(fn SSL_new(ctx: *SSL_CTX) -> *SSL) externfn!(fn SSL_free(ssl: *SSL)) externfn!(fn SSL_set_bio(ssl: *SSL, rbio: *BIO, wbio: *BIO)) +externfn!(fn SSL_get_rbio(ssl: *SSL) -> *BIO) +externfn!(fn SSL_get_wbio(ssl: *SSL) -> *BIO) externfn!(fn SSL_set_connect_state(ssl: *SSL)) externfn!(fn SSL_connect(ssl: *SSL) -> c_int) externfn!(fn SSL_get_error(ssl: *SSL, ret: c_int) -> c_int) @@ -54,5 +56,6 @@ externfn!(fn SSL_shutdown(ssl: *SSL) -> c_int) externfn!(fn BIO_s_mem() -> *BIO_METHOD) externfn!(fn BIO_new(type_: *BIO_METHOD) -> *BIO) +externfn!(fn BIO_free_all(a: *BIO)) externfn!(fn BIO_read(b: *BIO, buf: *c_void, len: c_int) -> c_int) externfn!(fn BIO_write(b: *BIO, buf: *c_void, len: c_int) -> c_int) diff --git a/lib.rs b/lib.rs index 9c5f5845..ce3cc8db 100644 --- a/lib.rs +++ b/lib.rs @@ -1,15 +1,19 @@ -use std::rt::io::{Reader, Writer, Stream, Decorator}; -use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release}; -use std::task; -use std::ptr; -use std::vec; use std::libc::{c_int, c_void}; +use std::ptr; +use std::task; +use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release}; +use std::rt::io::{Stream, Reader, Writer, Decorator}; +use std::vec; -mod ffi; +use error::{SslError, SslSessionClosed, StreamEof}; + +pub mod error; #[cfg(test)] mod tests; +mod ffi; + static mut STARTED_INIT: AtomicBool = INIT_ATOMIC_BOOL; static mut FINISHED_INIT: AtomicBool = INIT_ATOMIC_BOOL; @@ -35,7 +39,7 @@ pub enum SslMethod { } impl SslMethod { - unsafe fn to_fn(&self) -> *ffi::SSL_METHOD { + unsafe fn to_raw(&self) -> *ffi::SSL_METHOD { match *self { Sslv2 => ffi::SSLv2_method(), Sslv3 => ffi::SSLv3_method(), @@ -45,47 +49,140 @@ impl SslMethod { } } -pub struct SslCtx { +pub enum SslVerifyMode { + SslVerifyPeer = ffi::SSL_VERIFY_PEER, + SslVerifyNone = ffi::SSL_VERIFY_NONE +} + +pub struct SslContext { priv ctx: *ffi::SSL_CTX } -impl Drop for SslCtx { +impl Drop for SslContext { fn drop(&mut self) { - unsafe { ffi::SSL_CTX_free(self.ctx); } + unsafe { ffi::SSL_CTX_free(self.ctx) } } } -impl SslCtx { - pub fn new(method: SslMethod) -> SslCtx { +impl SslContext { + pub fn try_new(method: SslMethod) -> Result { init(); - let ctx = unsafe { ffi::SSL_CTX_new(method.to_fn()) }; - assert!(ctx != ptr::null()); + let ctx = unsafe { ffi::SSL_CTX_new(method.to_raw()) }; + if ctx == ptr::null() { + return Err(SslError::get().unwrap()); + } - SslCtx { - ctx: ctx + Ok(SslContext { ctx: ctx }) + } + + pub fn new(method: SslMethod) -> SslContext { + match SslContext::try_new(method) { + Ok(ctx) => ctx, + Err(err) => fail!("Error creating SSL context: {:?}", err) } } + // TODO: support callback (see SSL_CTX_set_ex_data) pub fn set_verify(&mut self, mode: SslVerifyMode) { - unsafe { ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None) } + unsafe { + ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None); + } } - pub fn set_verify_locations(&mut self, CAfile: &str) { - do CAfile.with_c_str |CAfile| { - unsafe { ffi::SSL_CTX_load_verify_locations(self.ctx, CAfile, - ptr::null()); } + pub fn set_CA_file(&mut self, file: &str) -> Option { + let ret = do file.with_c_str |file| { + unsafe { + ffi::SSL_CTX_load_verify_locations(self.ctx, file, ptr::null()) + } + }; + + if ret == 0 { + Some(SslError::get().unwrap()) + } else { + None } } } -pub enum SslVerifyMode { - SslVerifyNone = ffi::SSL_VERIFY_NONE, - SslVerifyPeer = ffi::SSL_VERIFY_PEER +struct Ssl { + ssl: *ffi::SSL } -#[deriving(Eq, FromPrimitive)] -enum SslError { +impl Drop for Ssl { + fn drop(&mut self) { + unsafe { ffi::SSL_free(self.ssl) } + } +} + +impl Ssl { + fn try_new(ctx: &SslContext) -> Result { + let ssl = unsafe { ffi::SSL_new(ctx.ctx) }; + if ssl == ptr::null() { + return Err(SslError::get().unwrap()); + } + let ssl = Ssl { ssl: ssl }; + + let rbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) }; + if rbio == ptr::null() { + return Err(SslError::get().unwrap()); + } + + let wbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) }; + if wbio == ptr::null() { + unsafe { ffi::BIO_free_all(rbio) } + return Err(SslError::get().unwrap()); + } + + unsafe { ffi::SSL_set_bio(ssl.ssl, rbio, wbio) } + Ok(ssl) + } + + fn get_rbio<'a>(&'a self) -> MemBio<'a> { + let bio = unsafe { ffi::SSL_get_rbio(self.ssl) }; + assert!(bio != ptr::null()); + + MemBio { + ssl: self, + bio: bio + } + } + + fn get_wbio<'a>(&'a self) -> MemBio<'a> { + let bio = unsafe { ffi::SSL_get_wbio(self.ssl) }; + assert!(bio != ptr::null()); + + MemBio { + ssl: self, + bio: bio + } + } + + fn connect(&self) -> c_int { + unsafe { ffi::SSL_connect(self.ssl) } + } + + fn read(&self, buf: &mut [u8]) -> c_int { + unsafe { ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void, + buf.len() as c_int) } + } + + fn write(&self, buf: &[u8]) -> c_int { + unsafe { ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void, + buf.len() as c_int) } + } + + fn get_error(&self, ret: c_int) -> LibSslError { + let err = unsafe { ffi::SSL_get_error(self.ssl, ret) }; + match FromPrimitive::from_int(err as int) { + Some(err) => err, + None => unreachable!() + } + } +} + +#[deriving(FromPrimitive)] +enum LibSslError { ErrorNone = ffi::SSL_ERROR_NONE, ErrorSsl = ffi::SSL_ERROR_SSL, ErrorWantRead = ffi::SSL_ERROR_WANT_READ, @@ -97,144 +194,72 @@ enum SslError { ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT, } -struct Ssl { - ssl: *ffi::SSL -} - -impl Drop for Ssl { - fn drop(&mut self) { - unsafe { ffi::SSL_free(self.ssl); } - } -} - -impl Ssl { - fn new(ctx: &SslCtx) -> Ssl { - let ssl = unsafe { ffi::SSL_new(ctx.ctx) }; - assert!(ssl != ptr::null()); - - Ssl { ssl: ssl } - } - - fn set_bio(&self, rbio: &MemBio, wbio: &MemBio) { - unsafe { ffi::SSL_set_bio(self.ssl, rbio.bio, wbio.bio); } - } - - fn set_connect_state(&self) { - unsafe { ffi::SSL_set_connect_state(self.ssl); } - } - - fn connect(&self) -> int { - unsafe { ffi::SSL_connect(self.ssl) as int } - } - - fn get_error(&self, ret: int) -> SslError { - let err = unsafe { ffi::SSL_get_error(self.ssl, ret as c_int) }; - match FromPrimitive::from_int(err as int) { - Some(err) => err, - None => fail2!("Unknown error {}", err) - } - } - - fn read(&self, buf: &[u8]) -> int { - unsafe { - ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void, - buf.len() as c_int) as int - } - } - - fn write(&self, buf: &[u8]) -> int { - unsafe { - ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void, - buf.len() as c_int) as int - } - } - - fn shutdown(&self) -> int { - unsafe { ffi::SSL_shutdown(self.ssl) as int } - } -} - -// BIOs are freed by SSL_free -struct MemBio { +struct MemBio<'self> { + ssl: &'self Ssl, bio: *ffi::BIO } -impl MemBio { - fn new() -> MemBio { - let bio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) }; - assert!(bio != ptr::null()); +impl<'self> MemBio<'self> { + fn read(&self, buf: &mut [u8]) -> Option { + let ret = unsafe { + ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void, + buf.len() as c_int) + }; - MemBio { bio: bio } + if ret < 0 { + None + } else { + Some(ret as uint) + } } fn write(&self, buf: &[u8]) { - unsafe { - let ret = ffi::BIO_write(self.bio, - vec::raw::to_ptr(buf) as *c_void, - buf.len() as c_int); - if ret < 0 { - fail2!("write returned {}", ret); - } - } - } - - fn read(&self, buf: &[u8]) -> uint { - unsafe { - let ret = ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void, - buf.len() as c_int); - if ret < 0 { - 0 - } else { - ret as uint - } - } + let ret = unsafe { + ffi::BIO_write(self.bio, vec::raw::to_ptr(buf) as *c_void, + buf.len() as c_int) + }; + assert_eq!(buf.len(), ret as uint); } } pub struct SslStream { - priv ctx: SslCtx, + priv stream: S, priv ssl: Ssl, - priv buf: ~[u8], - priv rbio: MemBio, - priv wbio: MemBio, - priv stream: S + priv buf: ~[u8] } impl SslStream { - pub fn new(ctx: SslCtx, stream: S) -> Result, uint> { - let ssl = Ssl::new(&ctx); + pub fn try_new(ctx: &SslContext, stream: S) -> Result, + SslError> { + let ssl = match Ssl::try_new(ctx) { + Ok(ssl) => ssl, + Err(err) => return Err(err) + }; - let rbio = MemBio::new(); - let wbio = MemBio::new(); - - ssl.set_bio(&rbio, &wbio); - ssl.set_connect_state(); - - let mut stream = SslStream { - ctx: ctx, + let mut ssl = SslStream { + stream: stream, ssl: ssl, - // Max record size for SSLv3/TLSv1 is 16k - buf: vec::from_elem(16 * 1024, 0u8), - rbio: rbio, - wbio: wbio, - stream: stream + // Maximum TLS record size is 16k + buf: vec::from_elem(16 * 1024, 0u8) }; - let ret = do stream.in_retry_wrapper |ssl| { - ssl.ssl.connect() - }; - - match ret { - Ok(_) => Ok(stream), - // FIXME - Err(_err) => Err(unsafe { ffi::ERR_get_error() as uint }) + match ssl.in_retry_wrapper(|ssl| { ssl.connect() }) { + Ok(_) => Ok(ssl), + Err(err) => Err(err) } } - fn in_retry_wrapper(&mut self, blk: &fn(&mut SslStream) -> int) - -> Result { + pub fn new(ctx: &SslContext, stream: S) -> SslStream { + match SslStream::try_new(ctx, stream) { + Ok(stream) => stream, + Err(err) => fail!("Error creating SSL stream: {:?}", err) + } + } + + fn in_retry_wrapper(&mut self, blk: &fn(&Ssl) -> c_int) + -> Result { loop { - let ret = blk(self); + let ret = blk(&self.ssl); if ret > 0 { return Ok(ret); } @@ -243,34 +268,24 @@ impl SslStream { ErrorWantRead => { self.flush(); match self.stream.read(self.buf) { - Some(len) => self.rbio.write(self.buf.slice_to(len)), - None => return Err(ErrorZeroReturn) // FIXME + Some(len) => + self.ssl.get_rbio().write(self.buf.slice_to(len)), + None => return Err(StreamEof) } } ErrorWantWrite => self.flush(), - err => return Err(err) + ErrorZeroReturn => return Err(SslSessionClosed), + ErrorSsl => return Err(SslError::get().unwrap()), + _ => unreachable!() } } } fn write_through(&mut self) { loop { - let len = self.wbio.read(self.buf); - if len == 0 { - return; - } - self.stream.write(self.buf.slice_to(len)); - } - } - - pub fn shutdown(&mut self) { - loop { - let ret = do self.in_retry_wrapper |ssl| { - ssl.ssl.shutdown() - }; - - if ret != Ok(0) { - break; + match self.ssl.get_wbio().read(self.buf) { + Some(len) => self.stream.write(self.buf.slice_to(len)), + None => break } } } @@ -278,13 +293,10 @@ impl SslStream { impl Reader for SslStream { fn read(&mut self, buf: &mut [u8]) -> Option { - let ret = do self.in_retry_wrapper |ssl| { - ssl.ssl.read(buf) - }; - - match ret { - Ok(num) => Some(num as uint), - Err(_) => None + match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) { + Ok(len) => Some(len as uint), + Err(StreamEof) | Err(SslSessionClosed) => None, + _ => unreachable!() } } @@ -295,25 +307,26 @@ impl Reader for SslStream { impl Writer for SslStream { fn write(&mut self, buf: &[u8]) { - let ret = do self.in_retry_wrapper |ssl| { - ssl.ssl.write(buf) - }; - - match ret { - Ok(_) => (), - Err(err) => fail2!("Write error: {:?}", err) + let mut start = 0; + while start < buf.len() { + let ret = do self.in_retry_wrapper |ssl| { + ssl.write(buf.slice_from(start)) + }; + match ret { + Ok(len) => start += len as uint, + _ => unreachable!() + } + self.write_through(); } - - self.write_through(); } fn flush(&mut self) { self.write_through(); - self.stream.flush(); + self.stream.flush() } } -impl Decorator for SslStream { +impl Decorator for SslStream { fn inner(self) -> S { self.stream } diff --git a/tests.rs b/tests.rs index 639ce1b1..b167cda8 100644 --- a/tests.rs +++ b/tests.rs @@ -3,37 +3,37 @@ use std::rt::io::extensions::ReaderUtil; use std::rt::io::net::tcp::TcpStream; use std::str; -use super::{Sslv23, SslCtx, SslStream, SslVerifyPeer}; +use super::{Sslv23, SslContext, SslStream, SslVerifyPeer}; #[test] fn test_new_ctx() { - SslCtx::new(Sslv23); + SslContext::new(Sslv23); } #[test] fn test_new_sslstream() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); - SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); + SslStream::new(&SslContext::new(Sslv23), stream); } #[test] fn test_verify_untrusted() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); - let mut ctx = SslCtx::new(Sslv23); + let mut ctx = SslContext::new(Sslv23); ctx.set_verify(SslVerifyPeer); - match SslStream::new(ctx, stream) { + match SslStream::try_new(&ctx, stream) { Ok(_) => fail2!("expected failure"), - Err(err) => println!("error {}", err) + Err(err) => println!("error {:?}", err) } } #[test] fn test_verify_trusted() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); - let mut ctx = SslCtx::new(Sslv23); + let mut ctx = SslContext::new(Sslv23); ctx.set_verify(SslVerifyPeer); - ctx.set_verify_locations("cert.pem"); - match SslStream::new(ctx, stream) { + assert!(ctx.set_CA_file("cert.pem").is_none()); + match SslStream::try_new(&ctx, stream) { Ok(_) => (), Err(err) => fail2!("Expected success, got {:?}", err) } @@ -42,18 +42,17 @@ fn test_verify_trusted() { #[test] fn test_write() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); - let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); + let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); stream.write("hello".as_bytes()); stream.flush(); stream.write(" there".as_bytes()); stream.flush(); - stream.shutdown(); } #[test] fn test_read() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); - let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); + let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); stream.write("GET /\r\n\r\n".as_bytes()); stream.flush(); let buf = stream.read_to_end();