From 8e01f8d2502098497e642ee477d926a99ee619a8 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 20 Dec 2016 14:04:10 -0800 Subject: [PATCH] Handle zero-length reads/writes This commit adds some short-circuits for zero-length reads/writes to `SslStream`. Because OpenSSL returns 0 on error, then we could mistakenly confuse a 0-length success as an actual error, so we avoid writing or reading 0 bytes by returning quickly with a success. --- openssl/src/ssl/mod.rs | 14 ++++++++++++++ openssl/src/ssl/tests/mod.rs | 10 ++++++++++ 2 files changed, 24 insertions(+) diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 6e0c92c3..47c83453 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -1506,6 +1506,15 @@ impl SslStream { /// This is particularly useful with a nonblocking socket, where the error /// value will identify if OpenSSL is waiting on read or write readiness. pub fn ssl_read(&mut self, buf: &mut [u8]) -> Result { + // The intepretation of the return code here is a little odd with a + // zero-length write. OpenSSL will likely correctly report back to us + // that it read zero bytes, but zero is also the sentinel for "error". + // To avoid that confusion short-circuit that logic and return quickly + // if `buf` has a length of zero. + if buf.len() == 0 { + return Ok(0) + } + let ret = self.ssl.read(buf); if ret > 0 { Ok(ret as usize) @@ -1523,6 +1532,11 @@ impl SslStream { /// This is particularly useful with a nonblocking socket, where the error /// value will identify if OpenSSL is waiting on read or write readiness. pub fn ssl_write(&mut self, buf: &[u8]) -> Result { + // See above for why we short-circuit on zero-length buffers + if buf.len() == 0 { + return Ok(0) + } + let ret = self.ssl.write(buf); if ret > 0 { Ok(ret as usize) diff --git a/openssl/src/ssl/tests/mod.rs b/openssl/src/ssl/tests/mod.rs index 2f6bbe1f..66f9dca9 100644 --- a/openssl/src/ssl/tests/mod.rs +++ b/openssl/src/ssl/tests/mod.rs @@ -421,6 +421,16 @@ fn test_write() { stream.flush().unwrap(); } +#[test] +fn zero_length_buffers() { + let (_s, stream) = Server::new(); + let ctx = SslContext::builder(SslMethod::tls()).unwrap(); + let mut stream = Ssl::new(&ctx.build()).unwrap().connect(stream).unwrap(); + + assert_eq!(stream.write(b"").unwrap(), 0); + assert_eq!(stream.read(&mut []).unwrap(), 0); +} + run_test!(get_peer_certificate, |method, stream| { let ctx = SslContext::builder(method).unwrap(); let stream = Ssl::new(&ctx.build()).unwrap().connect(stream).unwrap();