Implement direct IO support

This commit is contained in:
Steven Fackler 2015-06-28 00:06:14 -07:00
parent 9b235a7b91
commit 1373a76ce1
2 changed files with 187 additions and 13 deletions

View File

@ -603,11 +603,6 @@ impl Ssl {
return Err(SslError::get()); return Err(SslError::get());
} }
let ssl = Ssl { ssl: ssl }; let ssl = Ssl { ssl: ssl };
let rbio = try!(MemBio::new());
let wbio = try!(MemBio::new());
unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) }
Ok(ssl) Ok(ssl)
} }
@ -769,6 +764,12 @@ impl IndirectStream<net::TcpStream> {
impl<S: Read+Write> IndirectStream<S> { impl<S: Read+Write> IndirectStream<S> {
fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> { fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
let ssl = try!(ssl.into_ssl()); let ssl = try!(ssl.into_ssl());
let rbio = try!(MemBio::new());
let wbio = try!(MemBio::new());
unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) }
Ok(IndirectStream { Ok(IndirectStream {
stream: stream, stream: stream,
ssl: Arc::new(ssl), ssl: Arc::new(ssl),
@ -852,27 +853,139 @@ impl<S: Read+Write> Write for IndirectStream<S> {
} }
} }
#[derive(Clone)]
struct DirectStream<S> {
stream: S,
ssl: Arc<Ssl>,
}
impl DirectStream<net::TcpStream> {
fn try_clone(&self) -> io::Result<DirectStream<net::TcpStream>> {
Ok(DirectStream {
stream: try!(self.stream.try_clone()),
ssl: self.ssl.clone(),
})
}
}
impl<S> DirectStream<S> {
fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
unsafe {
let bio = ffi::BIO_new_socket(sock, 0);
if bio == ptr::null_mut() {
return Err(SslError::get());
}
ffi::SSL_set_bio(ssl.ssl, bio, bio);
}
Ok(DirectStream {
stream: stream,
ssl: Arc::new(ssl),
})
}
fn new_client(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
let ssl = try!(DirectStream::new_base(ssl, stream, sock));
let ret = ssl.ssl.connect();
if ret > 0 {
Ok(ssl)
} else {
Err(ssl.make_error(ret))
}
}
fn new_server(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
let ssl = try!(DirectStream::new_base(ssl, stream, sock));
let ret = ssl.ssl.accept();
if ret > 0 {
Ok(ssl)
} else {
Err(ssl.make_error(ret))
}
}
fn make_error(&self, ret: c_int) -> SslError {
match self.ssl.get_error(ret) {
LibSslError::ErrorSsl => SslError::get(),
LibSslError::ErrorSyscall => {
let err = SslError::get();
let count = match err {
SslError::OpenSslErrors(ref v) => v.len(),
_ => unreachable!(),
};
if count == 0 {
if ret == 0 {
SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
"unexpected EOF observed"))
} else {
SslError::StreamError(io::Error::last_os_error())
}
} else {
err
}
}
err => panic!("unexpected error {:?} with ret {}", err, ret),
}
}
}
impl<S> Read for DirectStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let ret = self.ssl.read(buf);
if ret >= 0 {
return Ok(ret as usize);
}
match self.make_error(ret) {
SslError::StreamError(e) => Err(e),
e => Err(io::Error::new(io::ErrorKind::Other, e)),
}
}
}
impl<S: Write> Write for DirectStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let ret = self.ssl.write(buf);
if ret > 0 {
return Ok(ret as usize);
}
match self.make_error(ret) {
SslError::StreamError(e) => Err(e),
e => Err(io::Error::new(io::ErrorKind::Other, e)),
}
}
fn flush(&mut self) -> io::Result<()> {
self.stream.flush()
}
}
#[derive(Clone)] #[derive(Clone)]
enum StreamKind<S> { enum StreamKind<S> {
Indirect(IndirectStream<S>), Indirect(IndirectStream<S>),
Direct(DirectStream<S>),
} }
impl<S> StreamKind<S> { impl<S> StreamKind<S> {
fn stream(&self) -> &S { fn stream(&self) -> &S {
match *self { match *self {
StreamKind::Indirect(ref s) => &s.stream StreamKind::Indirect(ref s) => &s.stream,
StreamKind::Direct(ref s) => &s.stream,
} }
} }
fn mut_stream(&mut self) -> &mut S { fn mut_stream(&mut self) -> &mut S {
match *self { match *self {
StreamKind::Indirect(ref mut s) => &mut s.stream StreamKind::Indirect(ref mut s) => &mut s.stream,
StreamKind::Direct(ref mut s) => &mut s.stream,
} }
} }
fn ssl(&self) -> &Ssl { fn ssl(&self) -> &Ssl {
match *self { match *self {
StreamKind::Indirect(ref s) => &s.ssl StreamKind::Indirect(ref s) => &s.ssl,
StreamKind::Direct(ref s) => &s.ssl,
} }
} }
} }
@ -887,7 +1000,8 @@ impl SslStream<net::TcpStream> {
/// Create a new independently owned handle to the underlying socket. /// Create a new independently owned handle to the underlying socket.
pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> { pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
let kind = match self.kind { let kind = match self.kind {
StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())) StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())),
StreamKind::Direct(ref s) => StreamKind::Direct(try!(s.try_clone()))
}; };
Ok(SslStream { Ok(SslStream {
kind: kind kind: kind
@ -901,6 +1015,46 @@ impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
} }
} }
#[cfg(unix)]
impl<S: ::std::os::unix::io::AsRawFd> SslStream<S> {
pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let ssl = try!(ssl.into_ssl());
let fd = stream.as_raw_fd() as c_int;
let stream = try!(DirectStream::new_client(ssl, stream, fd));
Ok(SslStream {
kind: StreamKind::Direct(stream)
})
}
pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let ssl = try!(ssl.into_ssl());
let fd = stream.as_raw_fd() as c_int;
let stream = try!(DirectStream::new_server(ssl, stream, fd));
Ok(SslStream {
kind: StreamKind::Direct(stream)
})
}
}
#[cfg(windows)]
impl<S: ::std::os::windows::io::AsRawSocket> SslStream<S> {
pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let fd = stream.as_raw_socket() as c_int;
let stream = try!(DirectStream::new_client(ssl, stream, fd));
Ok(SslStream {
kind: StreamKind::Direct(stream)
})
}
pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let fd = stream.as_raw_socket() as c_int;
let stream = try!(DirectStream::new_server(ssl, stream, fd));
Ok(SslStream {
kind: StreamKind::Direct(stream)
})
}
}
impl<S: Read+Write> SslStream<S> { impl<S: Read+Write> SslStream<S> {
pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> { pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
let stream = try!(IndirectStream::new_client(ssl, stream)); let stream = try!(IndirectStream::new_client(ssl, stream));
@ -994,7 +1148,8 @@ impl<S: Read+Write> SslStream<S> {
impl<S: Read+Write> Read for SslStream<S> { impl<S: Read+Write> Read for SslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.kind { match self.kind {
StreamKind::Indirect(ref mut s) => s.read(buf) StreamKind::Indirect(ref mut s) => s.read(buf),
StreamKind::Direct(ref mut s) => s.read(buf),
} }
} }
} }
@ -1002,13 +1157,15 @@ impl<S: Read+Write> Read for SslStream<S> {
impl<S: Read+Write> Write for SslStream<S> { impl<S: Read+Write> Write for SslStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.kind { match self.kind {
StreamKind::Indirect(ref mut s) => s.write(buf) StreamKind::Indirect(ref mut s) => s.write(buf),
StreamKind::Direct(ref mut s) => s.write(buf),
} }
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
match self.kind { match self.kind {
StreamKind::Indirect(ref mut s) => s.flush() StreamKind::Indirect(ref mut s) => s.flush(),
StreamKind::Direct(ref mut s) => s.flush(),
} }
} }
} }

View File

@ -317,8 +317,17 @@ fn test_write() {
stream.flush().unwrap(); stream.flush().unwrap();
} }
#[test]
fn test_write_direct() {
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
stream.write_all("hello".as_bytes()).unwrap();
stream.flush().unwrap();
stream.write_all(" there".as_bytes()).unwrap();
stream.flush().unwrap();
}
run_test!(get_peer_certificate, |method, stream| { run_test!(get_peer_certificate, |method, stream| {
//let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
let stream = SslStream::new_client(&SslContext::new(method).unwrap(), stream).unwrap(); let stream = SslStream::new_client(&SslContext::new(method).unwrap(), stream).unwrap();
let cert = stream.get_peer_certificate().unwrap(); let cert = stream.get_peer_certificate().unwrap();
let fingerprint = cert.fingerprint(SHA256).unwrap(); let fingerprint = cert.fingerprint(SHA256).unwrap();
@ -349,6 +358,14 @@ fn test_read() {
io::copy(&mut stream, &mut io::sink()).ok().expect("read error"); io::copy(&mut stream, &mut io::sink()).ok().expect("read error");
} }
#[test]
fn test_read_direct() {
let tcp = TcpStream::connect("127.0.0.1:15418").unwrap();
let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), tcp).unwrap();
stream.write_all("GET /\r\n\r\n".as_bytes()).unwrap();
stream.flush().unwrap();
io::copy(&mut stream, &mut io::sink()).ok().expect("read error");
}
#[test] #[test]
fn test_pending() { fn test_pending() {