From 49678805041cc2824ce54c9a0a8cf9fb5447838e Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 9 Mar 2014 17:01:37 -0700 Subject: [PATCH] Properly propogate errors --- ssl/mod.rs | 30 ++++++++++++++++++------------ ssl/tests.rs | 12 ++++++------ 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/ssl/mod.rs b/ssl/mod.rs index ad6dd023..e9b4e78b 100644 --- a/ssl/mod.rs +++ b/ssl/mod.rs @@ -18,6 +18,15 @@ static mut INIT: Once = ONCE_INIT; static mut VERIFY_IDX: c_int = -1; static mut MUTEXES: *mut ~[NativeMutex] = 0 as *mut ~[NativeMutex]; +macro_rules! try_ssl( + ($e:expr) => ( + match $e { + Ok(ok) => ok, + Err(err) => return Err(StreamError(err)) + } + ) +) + fn init() { unsafe { INIT.doit(|| { @@ -480,14 +489,11 @@ impl SslStream { match self.ssl.get_error(ret) { ErrorWantRead => { - self.flush(); - match self.stream.read(self.buf) { - Ok(len) => - self.ssl.get_rbio().write(self.buf.slice_to(len)), - Err(err) => return Err(StreamError(err)) - } + try_ssl!(self.flush()); + let len = try_ssl!(self.stream.read(self.buf)); + self.ssl.get_rbio().write(self.buf.slice_to(len)); } - ErrorWantWrite => { self.flush(); } + ErrorWantWrite => { try_ssl!(self.flush()) } ErrorZeroReturn => return Err(SslSessionClosed), ErrorSsl => return Err(SslError::get()), _ => unreachable!() @@ -495,14 +501,14 @@ impl SslStream { } } - fn write_through(&mut self) { + fn write_through(&mut self) -> IoResult<()> { loop { - // TODO propogate errors match self.ssl.get_wbio().read(self.buf) { - Some(len) => self.stream.write(self.buf.slice_to(len)), + Some(len) => try!(self.stream.write(self.buf.slice_to(len))), None => break }; } + Ok(()) } } @@ -533,13 +539,13 @@ impl Writer for SslStream { Ok(len) => start += len as uint, _ => unreachable!() } - self.write_through(); + try!(self.write_through()); } Ok(()) } fn flush(&mut self) -> IoResult<()> { - self.write_through(); + try!(self.write_through()); self.stream.flush() } } diff --git a/ssl/tests.rs b/ssl/tests.rs index 751ca7ab..c7f738f5 100644 --- a/ssl/tests.rs +++ b/ssl/tests.rs @@ -144,18 +144,18 @@ fn test_verify_trusted_get_error_err() { fn test_write() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); - stream.write("hello".as_bytes()); - stream.flush(); - stream.write(" there".as_bytes()); - stream.flush(); + stream.write("hello".as_bytes()).unwrap(); + stream.flush().unwrap(); + stream.write(" there".as_bytes()).unwrap(); + stream.flush().unwrap(); } #[test] fn test_read() { let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap(); let mut stream = SslStream::new(&SslContext::new(Sslv23), stream); - stream.write("GET /\r\n\r\n".as_bytes()); - stream.flush(); + stream.write("GET /\r\n\r\n".as_bytes()).unwrap(); + stream.flush().unwrap(); let buf = stream.read_to_end().ok().expect("read error"); print!("{}", str::from_utf8(buf)); }