Major rewrite for better error handling

This commit is contained in:
Steven Fackler 2013-10-21 22:51:18 -07:00
parent a42d5261f9
commit 302590c2b5
4 changed files with 223 additions and 190 deletions

18
error.rs Normal file
View File

@ -0,0 +1,18 @@
use std::libc::c_ulong;
use super::ffi;
pub enum SslError {
StreamEof,
SslSessionClosed,
UnknownError(c_ulong)
}
impl SslError {
pub fn get() -> Option<SslError> {
match unsafe { ffi::ERR_get_error() } {
0 => None,
err => Some(UnknownError(err))
}
}
}

3
ffi.rs
View File

@ -45,6 +45,8 @@ externfn!(fn SSL_CTX_load_verify_locations(ctx: *SSL_CTX, CAfile: *c_char,
externfn!(fn SSL_new(ctx: *SSL_CTX) -> *SSL)
externfn!(fn SSL_free(ssl: *SSL))
externfn!(fn SSL_set_bio(ssl: *SSL, rbio: *BIO, wbio: *BIO))
externfn!(fn SSL_get_rbio(ssl: *SSL) -> *BIO)
externfn!(fn SSL_get_wbio(ssl: *SSL) -> *BIO)
externfn!(fn SSL_set_connect_state(ssl: *SSL))
externfn!(fn SSL_connect(ssl: *SSL) -> c_int)
externfn!(fn SSL_get_error(ssl: *SSL, ret: c_int) -> c_int)
@ -54,5 +56,6 @@ externfn!(fn SSL_shutdown(ssl: *SSL) -> c_int)
externfn!(fn BIO_s_mem() -> *BIO_METHOD)
externfn!(fn BIO_new(type_: *BIO_METHOD) -> *BIO)
externfn!(fn BIO_free_all(a: *BIO))
externfn!(fn BIO_read(b: *BIO, buf: *c_void, len: c_int) -> c_int)
externfn!(fn BIO_write(b: *BIO, buf: *c_void, len: c_int) -> c_int)

369
lib.rs
View File

@ -1,15 +1,19 @@
use std::rt::io::{Reader, Writer, Stream, Decorator};
use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release};
use std::task;
use std::ptr;
use std::vec;
use std::libc::{c_int, c_void};
use std::ptr;
use std::task;
use std::unstable::atomics::{AtomicBool, INIT_ATOMIC_BOOL, Acquire, Release};
use std::rt::io::{Stream, Reader, Writer, Decorator};
use std::vec;
mod ffi;
use error::{SslError, SslSessionClosed, StreamEof};
pub mod error;
#[cfg(test)]
mod tests;
mod ffi;
static mut STARTED_INIT: AtomicBool = INIT_ATOMIC_BOOL;
static mut FINISHED_INIT: AtomicBool = INIT_ATOMIC_BOOL;
@ -35,7 +39,7 @@ pub enum SslMethod {
}
impl SslMethod {
unsafe fn to_fn(&self) -> *ffi::SSL_METHOD {
unsafe fn to_raw(&self) -> *ffi::SSL_METHOD {
match *self {
Sslv2 => ffi::SSLv2_method(),
Sslv3 => ffi::SSLv3_method(),
@ -45,47 +49,140 @@ impl SslMethod {
}
}
pub struct SslCtx {
pub enum SslVerifyMode {
SslVerifyPeer = ffi::SSL_VERIFY_PEER,
SslVerifyNone = ffi::SSL_VERIFY_NONE
}
pub struct SslContext {
priv ctx: *ffi::SSL_CTX
}
impl Drop for SslCtx {
impl Drop for SslContext {
fn drop(&mut self) {
unsafe { ffi::SSL_CTX_free(self.ctx); }
unsafe { ffi::SSL_CTX_free(self.ctx) }
}
}
impl SslCtx {
pub fn new(method: SslMethod) -> SslCtx {
impl SslContext {
pub fn try_new(method: SslMethod) -> Result<SslContext, SslError> {
init();
let ctx = unsafe { ffi::SSL_CTX_new(method.to_fn()) };
assert!(ctx != ptr::null());
let ctx = unsafe { ffi::SSL_CTX_new(method.to_raw()) };
if ctx == ptr::null() {
return Err(SslError::get().unwrap());
}
SslCtx {
ctx: ctx
Ok(SslContext { ctx: ctx })
}
pub fn new(method: SslMethod) -> SslContext {
match SslContext::try_new(method) {
Ok(ctx) => ctx,
Err(err) => fail!("Error creating SSL context: {:?}", err)
}
}
// TODO: support callback (see SSL_CTX_set_ex_data)
pub fn set_verify(&mut self, mode: SslVerifyMode) {
unsafe { ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None) }
unsafe {
ffi::SSL_CTX_set_verify(self.ctx, mode as c_int, None);
}
}
pub fn set_verify_locations(&mut self, CAfile: &str) {
do CAfile.with_c_str |CAfile| {
unsafe { ffi::SSL_CTX_load_verify_locations(self.ctx, CAfile,
ptr::null()); }
pub fn set_CA_file(&mut self, file: &str) -> Option<SslError> {
let ret = do file.with_c_str |file| {
unsafe {
ffi::SSL_CTX_load_verify_locations(self.ctx, file, ptr::null())
}
};
if ret == 0 {
Some(SslError::get().unwrap())
} else {
None
}
}
}
pub enum SslVerifyMode {
SslVerifyNone = ffi::SSL_VERIFY_NONE,
SslVerifyPeer = ffi::SSL_VERIFY_PEER
struct Ssl {
ssl: *ffi::SSL
}
#[deriving(Eq, FromPrimitive)]
enum SslError {
impl Drop for Ssl {
fn drop(&mut self) {
unsafe { ffi::SSL_free(self.ssl) }
}
}
impl Ssl {
fn try_new(ctx: &SslContext) -> Result<Ssl, SslError> {
let ssl = unsafe { ffi::SSL_new(ctx.ctx) };
if ssl == ptr::null() {
return Err(SslError::get().unwrap());
}
let ssl = Ssl { ssl: ssl };
let rbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
if rbio == ptr::null() {
return Err(SslError::get().unwrap());
}
let wbio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
if wbio == ptr::null() {
unsafe { ffi::BIO_free_all(rbio) }
return Err(SslError::get().unwrap());
}
unsafe { ffi::SSL_set_bio(ssl.ssl, rbio, wbio) }
Ok(ssl)
}
fn get_rbio<'a>(&'a self) -> MemBio<'a> {
let bio = unsafe { ffi::SSL_get_rbio(self.ssl) };
assert!(bio != ptr::null());
MemBio {
ssl: self,
bio: bio
}
}
fn get_wbio<'a>(&'a self) -> MemBio<'a> {
let bio = unsafe { ffi::SSL_get_wbio(self.ssl) };
assert!(bio != ptr::null());
MemBio {
ssl: self,
bio: bio
}
}
fn connect(&self) -> c_int {
unsafe { ffi::SSL_connect(self.ssl) }
}
fn read(&self, buf: &mut [u8]) -> c_int {
unsafe { ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int) }
}
fn write(&self, buf: &[u8]) -> c_int {
unsafe { ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int) }
}
fn get_error(&self, ret: c_int) -> LibSslError {
let err = unsafe { ffi::SSL_get_error(self.ssl, ret) };
match FromPrimitive::from_int(err as int) {
Some(err) => err,
None => unreachable!()
}
}
}
#[deriving(FromPrimitive)]
enum LibSslError {
ErrorNone = ffi::SSL_ERROR_NONE,
ErrorSsl = ffi::SSL_ERROR_SSL,
ErrorWantRead = ffi::SSL_ERROR_WANT_READ,
@ -97,144 +194,72 @@ enum SslError {
ErrorWantAccept = ffi::SSL_ERROR_WANT_ACCEPT,
}
struct Ssl {
ssl: *ffi::SSL
}
impl Drop for Ssl {
fn drop(&mut self) {
unsafe { ffi::SSL_free(self.ssl); }
}
}
impl Ssl {
fn new(ctx: &SslCtx) -> Ssl {
let ssl = unsafe { ffi::SSL_new(ctx.ctx) };
assert!(ssl != ptr::null());
Ssl { ssl: ssl }
}
fn set_bio(&self, rbio: &MemBio, wbio: &MemBio) {
unsafe { ffi::SSL_set_bio(self.ssl, rbio.bio, wbio.bio); }
}
fn set_connect_state(&self) {
unsafe { ffi::SSL_set_connect_state(self.ssl); }
}
fn connect(&self) -> int {
unsafe { ffi::SSL_connect(self.ssl) as int }
}
fn get_error(&self, ret: int) -> SslError {
let err = unsafe { ffi::SSL_get_error(self.ssl, ret as c_int) };
match FromPrimitive::from_int(err as int) {
Some(err) => err,
None => fail2!("Unknown error {}", err)
}
}
fn read(&self, buf: &[u8]) -> int {
unsafe {
ffi::SSL_read(self.ssl, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int) as int
}
}
fn write(&self, buf: &[u8]) -> int {
unsafe {
ffi::SSL_write(self.ssl, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int) as int
}
}
fn shutdown(&self) -> int {
unsafe { ffi::SSL_shutdown(self.ssl) as int }
}
}
// BIOs are freed by SSL_free
struct MemBio {
struct MemBio<'self> {
ssl: &'self Ssl,
bio: *ffi::BIO
}
impl MemBio {
fn new() -> MemBio {
let bio = unsafe { ffi::BIO_new(ffi::BIO_s_mem()) };
assert!(bio != ptr::null());
impl<'self> MemBio<'self> {
fn read(&self, buf: &mut [u8]) -> Option<uint> {
let ret = unsafe {
ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int)
};
MemBio { bio: bio }
if ret < 0 {
None
} else {
Some(ret as uint)
}
}
fn write(&self, buf: &[u8]) {
unsafe {
let ret = ffi::BIO_write(self.bio,
vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int);
if ret < 0 {
fail2!("write returned {}", ret);
}
}
}
fn read(&self, buf: &[u8]) -> uint {
unsafe {
let ret = ffi::BIO_read(self.bio, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int);
if ret < 0 {
0
} else {
ret as uint
}
}
let ret = unsafe {
ffi::BIO_write(self.bio, vec::raw::to_ptr(buf) as *c_void,
buf.len() as c_int)
};
assert_eq!(buf.len(), ret as uint);
}
}
pub struct SslStream<S> {
priv ctx: SslCtx,
priv stream: S,
priv ssl: Ssl,
priv buf: ~[u8],
priv rbio: MemBio,
priv wbio: MemBio,
priv stream: S
priv buf: ~[u8]
}
impl<S: Stream> SslStream<S> {
pub fn new(ctx: SslCtx, stream: S) -> Result<SslStream<S>, uint> {
let ssl = Ssl::new(&ctx);
pub fn try_new(ctx: &SslContext, stream: S) -> Result<SslStream<S>,
SslError> {
let ssl = match Ssl::try_new(ctx) {
Ok(ssl) => ssl,
Err(err) => return Err(err)
};
let rbio = MemBio::new();
let wbio = MemBio::new();
ssl.set_bio(&rbio, &wbio);
ssl.set_connect_state();
let mut stream = SslStream {
ctx: ctx,
let mut ssl = SslStream {
stream: stream,
ssl: ssl,
// Max record size for SSLv3/TLSv1 is 16k
buf: vec::from_elem(16 * 1024, 0u8),
rbio: rbio,
wbio: wbio,
stream: stream
// Maximum TLS record size is 16k
buf: vec::from_elem(16 * 1024, 0u8)
};
let ret = do stream.in_retry_wrapper |ssl| {
ssl.ssl.connect()
};
match ret {
Ok(_) => Ok(stream),
// FIXME
Err(_err) => Err(unsafe { ffi::ERR_get_error() as uint })
match ssl.in_retry_wrapper(|ssl| { ssl.connect() }) {
Ok(_) => Ok(ssl),
Err(err) => Err(err)
}
}
fn in_retry_wrapper(&mut self, blk: &fn(&mut SslStream<S>) -> int)
-> Result<int, SslError> {
pub fn new(ctx: &SslContext, stream: S) -> SslStream<S> {
match SslStream::try_new(ctx, stream) {
Ok(stream) => stream,
Err(err) => fail!("Error creating SSL stream: {:?}", err)
}
}
fn in_retry_wrapper(&mut self, blk: &fn(&Ssl) -> c_int)
-> Result<c_int, SslError> {
loop {
let ret = blk(self);
let ret = blk(&self.ssl);
if ret > 0 {
return Ok(ret);
}
@ -243,34 +268,24 @@ impl<S: Stream> SslStream<S> {
ErrorWantRead => {
self.flush();
match self.stream.read(self.buf) {
Some(len) => self.rbio.write(self.buf.slice_to(len)),
None => return Err(ErrorZeroReturn) // FIXME
Some(len) =>
self.ssl.get_rbio().write(self.buf.slice_to(len)),
None => return Err(StreamEof)
}
}
ErrorWantWrite => self.flush(),
err => return Err(err)
ErrorZeroReturn => return Err(SslSessionClosed),
ErrorSsl => return Err(SslError::get().unwrap()),
_ => unreachable!()
}
}
}
fn write_through(&mut self) {
loop {
let len = self.wbio.read(self.buf);
if len == 0 {
return;
}
self.stream.write(self.buf.slice_to(len));
}
}
pub fn shutdown(&mut self) {
loop {
let ret = do self.in_retry_wrapper |ssl| {
ssl.ssl.shutdown()
};
if ret != Ok(0) {
break;
match self.ssl.get_wbio().read(self.buf) {
Some(len) => self.stream.write(self.buf.slice_to(len)),
None => break
}
}
}
@ -278,13 +293,10 @@ impl<S: Stream> SslStream<S> {
impl<S: Stream> Reader for SslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> Option<uint> {
let ret = do self.in_retry_wrapper |ssl| {
ssl.ssl.read(buf)
};
match ret {
Ok(num) => Some(num as uint),
Err(_) => None
match self.in_retry_wrapper(|ssl| { ssl.read(buf) }) {
Ok(len) => Some(len as uint),
Err(StreamEof) | Err(SslSessionClosed) => None,
_ => unreachable!()
}
}
@ -295,25 +307,26 @@ impl<S: Stream> Reader for SslStream<S> {
impl<S: Stream> Writer for SslStream<S> {
fn write(&mut self, buf: &[u8]) {
let ret = do self.in_retry_wrapper |ssl| {
ssl.ssl.write(buf)
};
match ret {
Ok(_) => (),
Err(err) => fail2!("Write error: {:?}", err)
let mut start = 0;
while start < buf.len() {
let ret = do self.in_retry_wrapper |ssl| {
ssl.write(buf.slice_from(start))
};
match ret {
Ok(len) => start += len as uint,
_ => unreachable!()
}
self.write_through();
}
self.write_through();
}
fn flush(&mut self) {
self.write_through();
self.stream.flush();
self.stream.flush()
}
}
impl<S: Stream> Decorator<S> for SslStream<S> {
impl<S> Decorator<S> for SslStream<S> {
fn inner(self) -> S {
self.stream
}

View File

@ -3,37 +3,37 @@ use std::rt::io::extensions::ReaderUtil;
use std::rt::io::net::tcp::TcpStream;
use std::str;
use super::{Sslv23, SslCtx, SslStream, SslVerifyPeer};
use super::{Sslv23, SslContext, SslStream, SslVerifyPeer};
#[test]
fn test_new_ctx() {
SslCtx::new(Sslv23);
SslContext::new(Sslv23);
}
#[test]
fn test_new_sslstream() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
SslStream::new(SslCtx::new(Sslv23), stream).unwrap();
SslStream::new(&SslContext::new(Sslv23), stream);
}
#[test]
fn test_verify_untrusted() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut ctx = SslCtx::new(Sslv23);
let mut ctx = SslContext::new(Sslv23);
ctx.set_verify(SslVerifyPeer);
match SslStream::new(ctx, stream) {
match SslStream::try_new(&ctx, stream) {
Ok(_) => fail2!("expected failure"),
Err(err) => println!("error {}", err)
Err(err) => println!("error {:?}", err)
}
}
#[test]
fn test_verify_trusted() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut ctx = SslCtx::new(Sslv23);
let mut ctx = SslContext::new(Sslv23);
ctx.set_verify(SslVerifyPeer);
ctx.set_verify_locations("cert.pem");
match SslStream::new(ctx, stream) {
assert!(ctx.set_CA_file("cert.pem").is_none());
match SslStream::try_new(&ctx, stream) {
Ok(_) => (),
Err(err) => fail2!("Expected success, got {:?}", err)
}
@ -42,18 +42,17 @@ fn test_verify_trusted() {
#[test]
fn test_write() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap();
let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
stream.write("hello".as_bytes());
stream.flush();
stream.write(" there".as_bytes());
stream.flush();
stream.shutdown();
}
#[test]
fn test_read() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut stream = SslStream::new(SslCtx::new(Sslv23), stream).unwrap();
let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
stream.write("GET /\r\n\r\n".as_bytes());
stream.flush();
let buf = stream.read_to_end();