Properly propogate errors

This commit is contained in:
Steven Fackler 2014-03-09 17:01:37 -07:00
parent f5f10deadc
commit 4967880504
2 changed files with 24 additions and 18 deletions

View File

@ -18,6 +18,15 @@ static mut INIT: Once = ONCE_INIT;
static mut VERIFY_IDX: c_int = -1;
static mut MUTEXES: *mut ~[NativeMutex] = 0 as *mut ~[NativeMutex];
macro_rules! try_ssl(
($e:expr) => (
match $e {
Ok(ok) => ok,
Err(err) => return Err(StreamError(err))
}
)
)
fn init() {
unsafe {
INIT.doit(|| {
@ -480,14 +489,11 @@ impl<S: Stream> SslStream<S> {
match self.ssl.get_error(ret) {
ErrorWantRead => {
self.flush();
match self.stream.read(self.buf) {
Ok(len) =>
self.ssl.get_rbio().write(self.buf.slice_to(len)),
Err(err) => return Err(StreamError(err))
}
try_ssl!(self.flush());
let len = try_ssl!(self.stream.read(self.buf));
self.ssl.get_rbio().write(self.buf.slice_to(len));
}
ErrorWantWrite => { self.flush(); }
ErrorWantWrite => { try_ssl!(self.flush()) }
ErrorZeroReturn => return Err(SslSessionClosed),
ErrorSsl => return Err(SslError::get()),
_ => unreachable!()
@ -495,14 +501,14 @@ impl<S: Stream> SslStream<S> {
}
}
fn write_through(&mut self) {
fn write_through(&mut self) -> IoResult<()> {
loop {
// TODO propogate errors
match self.ssl.get_wbio().read(self.buf) {
Some(len) => self.stream.write(self.buf.slice_to(len)),
Some(len) => try!(self.stream.write(self.buf.slice_to(len))),
None => break
};
}
Ok(())
}
}
@ -533,13 +539,13 @@ impl<S: Stream> Writer for SslStream<S> {
Ok(len) => start += len as uint,
_ => unreachable!()
}
self.write_through();
try!(self.write_through());
}
Ok(())
}
fn flush(&mut self) -> IoResult<()> {
self.write_through();
try!(self.write_through());
self.stream.flush()
}
}

View File

@ -144,18 +144,18 @@ fn test_verify_trusted_get_error_err() {
fn test_write() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
stream.write("hello".as_bytes());
stream.flush();
stream.write(" there".as_bytes());
stream.flush();
stream.write("hello".as_bytes()).unwrap();
stream.flush().unwrap();
stream.write(" there".as_bytes()).unwrap();
stream.flush().unwrap();
}
#[test]
fn test_read() {
let stream = TcpStream::connect(FromStr::from_str("127.0.0.1:15418").unwrap()).unwrap();
let mut stream = SslStream::new(&SslContext::new(Sslv23), stream);
stream.write("GET /\r\n\r\n".as_bytes());
stream.flush();
stream.write("GET /\r\n\r\n".as_bytes()).unwrap();
stream.flush().unwrap();
let buf = stream.read_to_end().ok().expect("read error");
print!("{}", str::from_utf8(buf));
}