Write through to underlying stream for every write call

cc #208
This commit is contained in:
Steven Fackler 2015-04-30 00:14:06 -07:00
parent b3f397c476
commit 73617dabfa
3 changed files with 35 additions and 11 deletions

View File

@ -899,14 +899,14 @@ impl<S: Read+Write> Read for SslStream<S> {
impl<S: Read+Write> Write for SslStream<S> { impl<S: Read+Write> Write for SslStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.in_retry_wrapper(|ssl| ssl.write(buf)) { let count = match self.in_retry_wrapper(|ssl| ssl.write(buf)) {
Ok(len) => Ok(len as usize), Ok(len) => len as usize,
Err(SslSessionClosed) => Ok(0), Err(SslSessionClosed) => 0,
Err(StreamError(e)) => return Err(e), Err(StreamError(e)) => return Err(e),
Err(e @ OpenSslErrors(_)) => { Err(e @ OpenSslErrors(_)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
Err(io::Error::new(io::ErrorKind::Other, e)) };
} try!(self.write_through());
} Ok(count)
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {

View File

@ -4,9 +4,7 @@ use std::net::TcpStream;
use std::io; use std::io;
use std::io::prelude::*; use std::io::prelude::*;
use std::path::Path; use std::path::Path;
#[cfg(feature = "npn")]
use std::net::TcpListener; use std::net::TcpListener;
#[cfg(feature = "npn")]
use std::thread; use std::thread;
use std::fs::File; use std::fs::File;
@ -17,7 +15,6 @@ use ssl::SslMethod::Sslv23;
use ssl::{SslContext, SslStream, VerifyCallback}; use ssl::{SslContext, SslStream, VerifyCallback};
use ssl::SSL_VERIFY_PEER; use ssl::SSL_VERIFY_PEER;
use x509::X509StoreContext; use x509::X509StoreContext;
#[cfg(feature = "npn")]
use x509::X509FileType; use x509::X509FileType;
use x509::X509; use x509::X509;
use crypto::pkey::PKey; use crypto::pkey::PKey;
@ -237,6 +234,34 @@ run_test!(verify_callback_data, |method, stream| {
} }
}); });
// Make sure every write call translates to a write call to the underlying socket.
#[test]
fn test_write_hits_stream() {
let listener = TcpListener::bind("localhost:0").unwrap();
let addr = listener.local_addr().unwrap();
let guard = thread::spawn(move || {
let ctx = SslContext::new(Sslv23).unwrap();
let stream = TcpStream::connect(addr).unwrap();
let mut stream = SslStream::new(&ctx, stream).unwrap();
stream.write_all(b"hello").unwrap();
stream
});
let mut ctx = SslContext::new(Sslv23).unwrap();
ctx.set_verify(SSL_VERIFY_PEER, None);
ctx.set_certificate_file(&Path::new("test/cert.pem"), X509FileType::PEM).unwrap();
ctx.set_private_key_file(&Path::new("test/key.pem"), X509FileType::PEM).unwrap();
let stream = listener.accept().unwrap().0;
let mut stream = SslStream::new_server(&ctx, stream).unwrap();
let mut buf = [0; 5];
assert_eq!(5, stream.read(&mut buf).unwrap());
assert_eq!(&b"hello"[..], &buf[..]);
guard.join().unwrap();
}
#[test] #[test]
fn test_set_certificate_and_private_key() { fn test_set_certificate_and_private_key() {
let key_path = Path::new("test/key.pem"); let key_path = Path::new("test/key.pem");

View File

@ -2,7 +2,6 @@ use serialize::hex::FromHex;
use std::io; use std::io;
use std::path::Path; use std::path::Path;
use std::fs::File; use std::fs::File;
use std::str;
use crypto::hash::Type::{SHA256}; use crypto::hash::Type::{SHA256};
use x509::{X509, X509Generator}; use x509::{X509, X509Generator};