Properly propogate errors

This commit is contained in:
Steven Fackler 2014-03-09 17:01:37 -07:00
parent f5f10deadc
commit 4967880504
2 changed files with 24 additions and 18 deletions

View File

@ -18,6 +18,15 @@ static mut INIT: Once = ONCE_INIT;
static mut VERIFY_IDX: c_int = -1; static mut VERIFY_IDX: c_int = -1;
static mut MUTEXES: *mut ~[NativeMutex] = 0 as *mut ~[NativeMutex]; static mut MUTEXES: *mut ~[NativeMutex] = 0 as *mut ~[NativeMutex];
macro_rules! try_ssl(
($e:expr) => (
match $e {
Ok(ok) => ok,
Err(err) => return Err(StreamError(err))
}
)
)
fn init() { fn init() {
unsafe { unsafe {
INIT.doit(|| { INIT.doit(|| {
@ -480,14 +489,11 @@ impl<S: Stream> SslStream<S> {
match self.ssl.get_error(ret) { match self.ssl.get_error(ret) {
ErrorWantRead => { ErrorWantRead => {
self.flush(); try_ssl!(self.flush());
match self.stream.read(self.buf) { let len = try_ssl!(self.stream.read(self.buf));
Ok(len) => self.ssl.get_rbio().write(self.buf.slice_to(len));
self.ssl.get_rbio().write(self.buf.slice_to(len)),
Err(err) => return Err(StreamError(err))
} }
} ErrorWantWrite => { try_ssl!(self.flush()) }
ErrorWantWrite => { self.flush(); }
ErrorZeroReturn => return Err(SslSessionClosed), ErrorZeroReturn => return Err(SslSessionClosed),
ErrorSsl => return Err(SslError::get()), ErrorSsl => return Err(SslError::get()),
_ => unreachable!() _ => unreachable!()
@ -495,14 +501,14 @@ impl<S: Stream> SslStream<S> {
} }
} }
fn write_through(&mut self) { fn write_through(&mut self) -> IoResult<()> {
loop { loop {
// TODO propogate errors
match self.ssl.get_wbio().read(self.buf) { match self.ssl.get_wbio().read(self.buf) {
Some(len) => self.stream.write(self.buf.slice_to(len)), Some(len) => try!(self.stream.write(self.buf.slice_to(len))),
None => break None => break
}; };
} }
Ok(())
} }
} }
@ -533,13 +539,13 @@ impl<S: Stream> Writer for SslStream<S> {
Ok(len) => start += len as uint, Ok(len) => start += len as uint,
_ => unreachable!() _ => unreachable!()
} }
self.write_through(); try!(self.write_through());
} }
Ok(()) Ok(())
} }
fn flush(&mut self) -> IoResult<()> { fn flush(&mut self) -> IoResult<()> {
self.write_through(); try!(self.write_through());
self.stream.flush() self.stream.flush()
} }
} }

View File

@ -144,18 +144,18 @@ fn test_verify_trusted_get_error_err() {
fn test_write() { fn test_write() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
stream.write("hello".as_bytes()); stream.write("hello".as_bytes()).unwrap();
stream.flush(); stream.flush().unwrap();
stream.write(" there".as_bytes()); stream.write(" there".as_bytes()).unwrap();
stream.flush(); stream.flush().unwrap();
} }
#[test] #[test]
fn test_read() { fn test_read() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
stream.write("GET /\r\n\r\n".as_bytes()); stream.write("GET /\r\n\r\n".as_bytes()).unwrap();
stream.flush(); stream.flush().unwrap();
let buf = stream.read_to_end().ok().expect("read error"); let buf = stream.read_to_end().ok().expect("read error");
print!("{}", str::from_utf8(buf)); print!("{}", str::from_utf8(buf));
} }