Major rewrite for better error handling

This commit is contained in:
Steven Fackler 2013-10-21 22:51:18 -07:00
parent a42d5261f9
commit 302590c2b5
4 changed files with 223 additions and 190 deletions

18
error.rs Normal file
View File

@ -0,0 +1,18 @@
use std::libc::c_ulong;
use super::ffi;
pub enum SslError {
StreamEof,
SslSessionClosed,
UnknownError(c_ulong)
}
impl SslError {
pub fn get() -> Option<SslError> {
match unsafe { ffi::ERR_get_error() } {
0 => None,
err => Some(UnknownError(err))
}
}
}

3
ffi.rs
View File

@ -45,6 +45,8 @@ externfn!(fn SSL_CTX_load_verify_locations(ctx: *SSL_CTX, CAfile: *c_char,
externfn!(fn SSL_new(ctx: *SSL_CTX) -> *SSL) externfn!(fn SSL_new(ctx: *SSL_CTX) -> *SSL)
externfn!(fn SSL_free(ssl: *SSL)) externfn!(fn SSL_free(ssl: *SSL))
externfn!(fn SSL_set_bio(ssl: *SSL, rbio: *BIO, wbio: *BIO)) externfn!(fn SSL_set_bio(ssl: *SSL, rbio: *BIO, wbio: *BIO))
externfn!(fn SSL_get_rbio(ssl: *SSL) -> *BIO)
externfn!(fn SSL_get_wbio(ssl: *SSL) -> *BIO)
externfn!(fn SSL_set_connect_state(ssl: *SSL)) externfn!(fn SSL_set_connect_state(ssl: *SSL))
externfn!(fn SSL_connect(ssl: *SSL) -> c_int) externfn!(fn SSL_connect(ssl: *SSL) -> c_int)
externfn!(fn SSL_get_error(ssl: *SSL, ret: c_int) -> c_int) externfn!(fn SSL_get_error(ssl: *SSL, ret: c_int) -> c_int)
@ -54,5 +56,6 @@ externfn!(fn SSL_shutdown(ssl: *SSL) -> c_int)
externfn!(fn BIO_s_mem() -> *BIO_METHOD) externfn!(fn BIO_s_mem() -> *BIO_METHOD)
externfn!(fn BIO_new(type_: *BIO_METHOD) -> *BIO) externfn!(fn BIO_new(type_: *BIO_METHOD) -> *BIO)
externfn!(fn BIO_free_all(a: *BIO))
externfn!(fn BIO_read(b: *BIO, buf: *c_void, len: c_int) -> c_int) externfn!(fn BIO_read(b: *BIO, buf: *c_void, len: c_int) -> c_int)
externfn!(fn BIO_write(b: *BIO, buf: *c_void, len: c_int) -> c_int) externfn!(fn BIO_write(b: *BIO, buf: *c_void, len: c_int) -> c_int)

369
lib.rs
View File

@ -1,15 +1,19 @@
use std::rt::io::{Reader, Writer, Stream, Decorator};
use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release};
use std::task;
use std::ptr;
use std::vec;
use std::libc::{c_int, c_void}; use std::libc::{c_int, c_void};
use std::ptr;
use std::task;
use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release};
use std::rt::io::{Stream, Reader, Writer, Decorator};
use std::vec;
mod ffi; use error::{SslError, SslSessionClosed, StreamEof};
pub mod error;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
mod ffi;
static mut STARTED_INIT: AtomicBool = INIT_ATOMIC_BOOL; static mut STARTED_INIT: AtomicBool = INIT_ATOMIC_BOOL;
static mut FINISHED_INIT: AtomicBool = INIT_ATOMIC_BOOL; static mut FINISHED_INIT: AtomicBool = INIT_ATOMIC_BOOL;
@ -35,7 +39,7 @@ pub enum SslMethod {
} }
impl SslMethod { impl SslMethod {
unsafe fn to_fn(&self) -> *ffi::SSL_METHOD { unsafe fn to_raw(&self) -> *ffi::SSL_METHOD {
match *self { match *self {
Sslv2 => ffi::SSLv2_method(), Sslv2 => ffi::SSLv2_method(),
Sslv3 => ffi::SSLv3_method(), Sslv3 => ffi::SSLv3_method(),
@ -45,47 +49,140 @@ impl SslMethod {
} }
} }
pub struct SslCtx { pub enum SslVerifyMode {
SslVerifyPeer = ffi::SSL_VERIFY_PEER,
SslVerifyNone = ffi::SSL_VERIFY_NONE
}
pub struct SslContext {
priv ctx: *ffi::SSL_CTX priv ctx: *ffi::SSL_CTX
} }
impl Drop for SslCtx { impl Drop for SslContext {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { ffi::SSL_CTX_free(self.ctx); } unsafe { ffi::SSL_CTX_free(self.ctx) }
} }
} }
impl SslCtx { impl SslContext {
pub fn new(method: SslMethod) -> SslCtx { pub fn try_new(method: SslMethod) -> Result<SslContext, SslError> {
init(); init();
let ctx = unsafe { ffi::SSL_CTX_new(method.to_fn()) }; let ctx = unsafe { ffi::SSL_CTX_new(method.to_raw()) };
assert!(ctx != ptr::null()); if ctx == ptr::null() {
return Err(SslError::get().unwrap());
}
SslCtx { Ok(SslContext { ctx: ctx })
ctx: ctx }
pub fn new(method: SslMethod) -> SslContext {
match SslContext::try_new(method) {
Ok(ctx) => ctx,
Err(err) => fail!("Error creating SSL context: {:?}", err)
} }
} }
// TODO: support callback (see SSL_CTX_set_ex_data)
pub fn set_verify(&mut self, mode: SslVerifyMode) { pub fn set_verify(&mut self, mode: SslVerifyMode) {
unsafe { ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None) } unsafe {
ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None);
}
} }
pub fn set_verify_locations(&mut self, CAfile: &str) { pub fn set_CA_file(&mut self, file: &str) -> Option<SslError> {
do CAfile.with_c_str |CAfile| { let ret = do file.with_c_str |file| {
unsafe { ffi::SSL_CTX_load_verify_locations(self.ctx, CAfile, unsafe {
ptr::null()); } ffi::SSL_CTX_load_verify_locations(self.ctx, file, ptr::null())
}
};
if ret == 0 {
Some(SslError::get().unwrap())
} else {
None
} }
} }
} }
pub enum SslVerifyMode { struct Ssl {
SslVerifyNone = ffi::SSL_VERIFY_NONE, ssl: *ffi::SSL
SslVerifyPeer = ffi::SSL_VERIFY_PEER
} }
#[deriving(Eq, FromPrimitive)] impl Drop for Ssl {
enum SslError { fn drop(&mut self) {
unsafe { ffi::SSL_free(self.ssl) }
}
}
impl Ssl {
fn try_new(ctx: &SslContext) -> Result<Ssl, SslError> {
let ssl = unsafe { ffi::SSL_new(ctx.ctx) };
if ssl == ptr::null() {
return Err(SslError::get().unwrap());
}
let ssl = Ssl { ssl: ssl };
let rbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
if rbio == ptr::null() {
return Err(SslError::get().unwrap());
}
let wbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
if wbio == ptr::null() {
unsafe { ffi::BIO_free_all(rbio) }
return Err(SslError::get().unwrap());
}
unsafe { ffi::SSL_set_bio(ssl.ssl, rbio, wbio) }
Ok(ssl)
}
fn get_rbio<'a>(&'a self) -> MemBio<'a> {
let bio = unsafe { ffi::SSL_get_rbio(self.ssl) };
assert!(bio != ptr::null());
MemBio {
ssl: self,
bio: bio
}
}
fn get_wbio<'a>(&'a self) -> MemBio<'a> {
let bio = unsafe { ffi::SSL_get_wbio(self.ssl) };
assert!(bio != ptr::null());
MemBio {
ssl: self,
bio: bio
}
}
fn connect(&self) -> c_int {
unsafe { ffi::SSL_connect(self.ssl) }
}
fn read(&self, buf: &mut [u8]) -> c_int {
unsafe { ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int) }
}
fn write(&self, buf: &[u8]) -> c_int {
unsafe { ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int) }
}
fn get_error(&self, ret: c_int) -> LibSslError {
let err = unsafe { ffi::SSL_get_error(self.ssl, ret) };
match FromPrimitive::from_int(err as int) {
Some(err) => err,
None => unreachable!()
}
}
}
#[deriving(FromPrimitive)]
enum LibSslError {
ErrorNone = ffi::SSL_ERROR_NONE, ErrorNone = ffi::SSL_ERROR_NONE,
ErrorSsl = ffi::SSL_ERROR_SSL, ErrorSsl = ffi::SSL_ERROR_SSL,
ErrorWantRead = ffi::SSL_ERROR_WANT_READ, ErrorWantRead = ffi::SSL_ERROR_WANT_READ,
@ -97,144 +194,72 @@ enum SslError {
ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT, ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT,
} }
struct Ssl { struct MemBio<'self> {
ssl: *ffi::SSL ssl: &'self Ssl,
}
impl Drop for Ssl {
fn drop(&mut self) {
unsafe { ffi::SSL_free(self.ssl); }
}
}
impl Ssl {
fn new(ctx: &SslCtx) -> Ssl {
let ssl = unsafe { ffi::SSL_new(ctx.ctx) };
assert!(ssl != ptr::null());
Ssl { ssl: ssl }
}
fn set_bio(&self, rbio: &MemBio, wbio: &MemBio) {
unsafe { ffi::SSL_set_bio(self.ssl, rbio.bio, wbio.bio); }
}
fn set_connect_state(&self) {
unsafe { ffi::SSL_set_connect_state(self.ssl); }
}
fn connect(&self) -> int {
unsafe { ffi::SSL_connect(self.ssl) as int }
}
fn get_error(&self, ret: int) -> SslError {
let err = unsafe { ffi::SSL_get_error(self.ssl, ret as c_int) };
match FromPrimitive::from_int(err as int) {
Some(err) => err,
None => fail2!("Unknown error {}", err)
}
}
fn read(&self, buf: &[u8]) -> int {
unsafe {
ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int) as int
}
}
fn write(&self, buf: &[u8]) -> int {
unsafe {
ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int) as int
}
}
fn shutdown(&self) -> int {
unsafe { ffi::SSL_shutdown(self.ssl) as int }
}
}
// BIOs are freed by SSL_free
struct MemBio {
bio: *ffi::BIO bio: *ffi::BIO
} }
impl MemBio { impl<'self> MemBio<'self> {
fn new() -> MemBio { fn read(&self, buf: &mut [u8]) -> Option<uint> {
let bio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) }; let ret = unsafe {
assert!(bio != ptr::null()); ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int)
};
MemBio { bio: bio } if ret < 0 {
None
} else {
Some(ret as uint)
}
} }
fn write(&self, buf: &[u8]) { fn write(&self, buf: &[u8]) {
unsafe { let ret = unsafe {
let ret = ffi::BIO_write(self.bio, ffi::BIO_write(self.bio, vec::raw::to_ptr(buf) as *c_void,
vec::raw::to_ptr(buf) as *c_void, buf.len() as c_int)
buf.len() as c_int); };
if ret < 0 { assert_eq!(buf.len(), ret as uint);
fail2!("write returned {}", ret);
}
}
}
fn read(&self, buf: &[u8]) -> uint {
unsafe {
let ret = ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int);
if ret < 0 {
0
} else {
ret as uint
}
}
} }
} }
pub struct SslStream<S> { pub struct SslStream<S> {
priv ctx: SslCtx, priv stream: S,
priv ssl: Ssl, priv ssl: Ssl,
priv buf: ~[u8], priv buf: ~[u8]
priv rbio: MemBio,
priv wbio: MemBio,
priv stream: S
} }
impl<S: Stream> SslStream<S> { impl<S: Stream> SslStream<S> {
pub fn new(ctx: SslCtx, stream: S) -> Result<SslStream<S>, uint> { pub fn try_new(ctx: &SslContext, stream: S) -> Result<SslStream<S>,
let ssl = Ssl::new(&ctx); SslError> {
let ssl = match Ssl::try_new(ctx) {
Ok(ssl) => ssl,
Err(err) => return Err(err)
};
let rbio = MemBio::new(); let mut ssl = SslStream {
let wbio = MemBio::new(); stream: stream,
ssl.set_bio(&rbio, &wbio);
ssl.set_connect_state();
let mut stream = SslStream {
ctx: ctx,
ssl: ssl, ssl: ssl,
// Max record size for SSLv3/TLSv1 is 16k // Maximum TLS record size is 16k
buf: vec::from_elem(16 * 1024, 0u8), buf: vec::from_elem(16 * 1024, 0u8)
rbio: rbio,
wbio: wbio,
stream: stream
}; };
let ret = do stream.in_retry_wrapper |ssl| { match ssl.in_retry_wrapper(|ssl| { ssl.connect() }) {
ssl.ssl.connect() Ok(_) => Ok(ssl),
}; Err(err) => Err(err)
match ret {
Ok(_) => Ok(stream),
// FIXME
Err(_err) => Err(unsafe { ffi::ERR_get_error() as uint })
} }
} }
fn in_retry_wrapper(&mut self, blk: &fn(&mut SslStream<S>) -> int) pub fn new(ctx: &SslContext, stream: S) -> SslStream<S> {
-> Result<int, SslError> { match SslStream::try_new(ctx, stream) {
Ok(stream) => stream,
Err(err) => fail!("Error creating SSL stream: {:?}", err)
}
}
fn in_retry_wrapper(&mut self, blk: &fn(&Ssl) -> c_int)
-> Result<c_int, SslError> {
loop { loop {
let ret = blk(self); let ret = blk(&self.ssl);
if ret > 0 { if ret > 0 {
return Ok(ret); return Ok(ret);
} }
@ -243,34 +268,24 @@ impl<S: Stream> SslStream<S> {
ErrorWantRead => { ErrorWantRead => {
self.flush(); self.flush();
match self.stream.read(self.buf) { match self.stream.read(self.buf) {
Some(len) => self.rbio.write(self.buf.slice_to(len)), Some(len) =>
None => return Err(ErrorZeroReturn) // FIXME self.ssl.get_rbio().write(self.buf.slice_to(len)),
None => return Err(StreamEof)
} }
} }
ErrorWantWrite => self.flush(), ErrorWantWrite => self.flush(),
err => return Err(err) ErrorZeroReturn => return Err(SslSessionClosed),
ErrorSsl => return Err(SslError::get().unwrap()),
_ => unreachable!()
} }
} }
} }
fn write_through(&mut self) { fn write_through(&mut self) {
loop { loop {
let len = self.wbio.read(self.buf); match self.ssl.get_wbio().read(self.buf) {
if len == 0 { Some(len) => self.stream.write(self.buf.slice_to(len)),
return; None => break
}
self.stream.write(self.buf.slice_to(len));
}
}
pub fn shutdown(&mut self) {
loop {
let ret = do self.in_retry_wrapper |ssl| {
ssl.ssl.shutdown()
};
if ret != Ok(0) {
break;
} }
} }
} }
@ -278,13 +293,10 @@ impl<S: Stream> SslStream<S> {
impl<S: Stream> Reader for SslStream<S> { impl<S: Stream> Reader for SslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> Option<uint> { fn read(&mut self, buf: &mut [u8]) -> Option<uint> {
let ret = do self.in_retry_wrapper |ssl| { match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
ssl.ssl.read(buf) Ok(len) => Some(len as uint),
}; Err(StreamEof) | Err(SslSessionClosed) => None,
_ => unreachable!()
match ret {
Ok(num) => Some(num as uint),
Err(_) => None
} }
} }
@ -295,25 +307,26 @@ impl<S: Stream> Reader for SslStream<S> {
impl<S: Stream> Writer for SslStream<S> { impl<S: Stream> Writer for SslStream<S> {
fn write(&mut self, buf: &[u8]) { fn write(&mut self, buf: &[u8]) {
let ret = do self.in_retry_wrapper |ssl| { let mut start = 0;
ssl.ssl.write(buf) while start < buf.len() {
}; let ret = do self.in_retry_wrapper |ssl| {
ssl.write(buf.slice_from(start))
match ret { };
Ok(_) => (), match ret {
Err(err) => fail2!("Write error: {:?}", err) Ok(len) => start += len as uint,
_ => unreachable!()
}
self.write_through();
} }
self.write_through();
} }
fn flush(&mut self) { fn flush(&mut self) {
self.write_through(); self.write_through();
self.stream.flush(); self.stream.flush()
} }
} }
impl<S: Stream> Decorator<S> for SslStream<S> { impl<S> Decorator<S> for SslStream<S> {
fn inner(self) -> S { fn inner(self) -> S {
self.stream self.stream
} }

View File

@ -3,37 +3,37 @@ use std::rt::io::extensions::ReaderUtil;
use std::rt::io::net::tcp::TcpStream; use std::rt::io::net::tcp::TcpStream;
use std::str; use std::str;
use super::{Sslv23, SslCtx, SslStream, SslVerifyPeer}; use super::{Sslv23, SslContext, SslStream, SslVerifyPeer};
#[test] #[test]
fn test_new_ctx() { fn test_new_ctx() {
SslCtx::new(Sslv23); SslContext::new(Sslv23);
} }
#[test] #[test]
fn test_new_sslstream() { fn test_new_sslstream() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); SslStream::new(&SslContext::new(Sslv23), stream);
} }
#[test] #[test]
fn test_verify_untrusted() { fn test_verify_untrusted() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut ctx = SslCtx::new(Sslv23); let mut ctx = SslContext::new(Sslv23);
ctx.set_verify(SslVerifyPeer); ctx.set_verify(SslVerifyPeer);
match SslStream::new(ctx, stream) { match SslStream::try_new(&ctx, stream) {
Ok(_) => fail2!("expected failure"), Ok(_) => fail2!("expected failure"),
Err(err) => println!("error {}", err) Err(err) => println!("error {:?}", err)
} }
} }
#[test] #[test]
fn test_verify_trusted() { fn test_verify_trusted() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut ctx = SslCtx::new(Sslv23); let mut ctx = SslContext::new(Sslv23);
ctx.set_verify(SslVerifyPeer); ctx.set_verify(SslVerifyPeer);
ctx.set_verify_locations("cert.pem"); assert!(ctx.set_CA_file("cert.pem").is_none());
match SslStream::new(ctx, stream) { match SslStream::try_new(&ctx, stream) {
Ok(_) => (), Ok(_) => (),
Err(err) => fail2!("Expected success, got {:?}", err) Err(err) => fail2!("Expected success, got {:?}", err)
} }
@ -42,18 +42,17 @@ fn test_verify_trusted() {
#[test] #[test]
fn test_write() { fn test_write() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
stream.write("hello".as_bytes()); stream.write("hello".as_bytes());
stream.flush(); stream.flush();
stream.write(" there".as_bytes()); stream.write(" there".as_bytes());
stream.flush(); stream.flush();
stream.shutdown();
} }
#[test] #[test]
fn test_read() { fn test_read() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap(); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
stream.write("GET /\r\n\r\n".as_bytes()); stream.write("GET /\r\n\r\n".as_bytes());
stream.flush(); stream.flush();
let buf = stream.read_to_end(); let buf = stream.read_to_end();