Implement direct IO support
This commit is contained in:
parent
9b235a7b91
commit
1373a76ce1
|
|
@ -603,11 +603,6 @@ impl Ssl {
|
|||
return Err(SslError::get());
|
||||
}
|
||||
let ssl = Ssl { ssl: ssl };
|
||||
|
||||
let rbio = try!(MemBio::new());
|
||||
let wbio = try!(MemBio::new());
|
||||
|
||||
unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) }
|
||||
Ok(ssl)
|
||||
}
|
||||
|
||||
|
|
@ -769,6 +764,12 @@ impl IndirectStream<net::TcpStream> {
|
|||
impl<S: Read+Write> IndirectStream<S> {
|
||||
fn new_base<T: IntoSsl>(ssl: T, stream: S) -> Result<IndirectStream<S>, SslError> {
|
||||
let ssl = try!(ssl.into_ssl());
|
||||
|
||||
let rbio = try!(MemBio::new());
|
||||
let wbio = try!(MemBio::new());
|
||||
|
||||
unsafe { ffi::SSL_set_bio(ssl.ssl, rbio.unwrap(), wbio.unwrap()) }
|
||||
|
||||
Ok(IndirectStream {
|
||||
stream: stream,
|
||||
ssl: Arc::new(ssl),
|
||||
|
|
@ -852,27 +853,139 @@ impl<S: Read+Write> Write for IndirectStream<S> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct DirectStream<S> {
|
||||
stream: S,
|
||||
ssl: Arc<Ssl>,
|
||||
}
|
||||
|
||||
impl DirectStream<net::TcpStream> {
|
||||
fn try_clone(&self) -> io::Result<DirectStream<net::TcpStream>> {
|
||||
Ok(DirectStream {
|
||||
stream: try!(self.stream.try_clone()),
|
||||
ssl: self.ssl.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> DirectStream<S> {
|
||||
fn new_base(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
|
||||
unsafe {
|
||||
let bio = ffi::BIO_new_socket(sock, 0);
|
||||
if bio == ptr::null_mut() {
|
||||
return Err(SslError::get());
|
||||
}
|
||||
ffi::SSL_set_bio(ssl.ssl, bio, bio);
|
||||
}
|
||||
|
||||
Ok(DirectStream {
|
||||
stream: stream,
|
||||
ssl: Arc::new(ssl),
|
||||
})
|
||||
}
|
||||
|
||||
fn new_client(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
|
||||
let ssl = try!(DirectStream::new_base(ssl, stream, sock));
|
||||
let ret = ssl.ssl.connect();
|
||||
if ret > 0 {
|
||||
Ok(ssl)
|
||||
} else {
|
||||
Err(ssl.make_error(ret))
|
||||
}
|
||||
}
|
||||
|
||||
fn new_server(ssl: Ssl, stream: S, sock: c_int) -> Result<DirectStream<S>, SslError> {
|
||||
let ssl = try!(DirectStream::new_base(ssl, stream, sock));
|
||||
let ret = ssl.ssl.accept();
|
||||
if ret > 0 {
|
||||
Ok(ssl)
|
||||
} else {
|
||||
Err(ssl.make_error(ret))
|
||||
}
|
||||
}
|
||||
|
||||
fn make_error(&self, ret: c_int) -> SslError {
|
||||
match self.ssl.get_error(ret) {
|
||||
LibSslError::ErrorSsl => SslError::get(),
|
||||
LibSslError::ErrorSyscall => {
|
||||
let err = SslError::get();
|
||||
let count = match err {
|
||||
SslError::OpenSslErrors(ref v) => v.len(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
if count == 0 {
|
||||
if ret == 0 {
|
||||
SslError::StreamError(io::Error::new(io::ErrorKind::ConnectionAborted,
|
||||
"unexpected EOF observed"))
|
||||
} else {
|
||||
SslError::StreamError(io::Error::last_os_error())
|
||||
}
|
||||
} else {
|
||||
err
|
||||
}
|
||||
}
|
||||
err => panic!("unexpected error {:?} with ret {}", err, ret),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Read for DirectStream<S> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let ret = self.ssl.read(buf);
|
||||
if ret >= 0 {
|
||||
return Ok(ret as usize);
|
||||
}
|
||||
|
||||
match self.make_error(ret) {
|
||||
SslError::StreamError(e) => Err(e),
|
||||
e => Err(io::Error::new(io::ErrorKind::Other, e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Write> Write for DirectStream<S> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
let ret = self.ssl.write(buf);
|
||||
if ret > 0 {
|
||||
return Ok(ret as usize);
|
||||
}
|
||||
|
||||
match self.make_error(ret) {
|
||||
SslError::StreamError(e) => Err(e),
|
||||
e => Err(io::Error::new(io::ErrorKind::Other, e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.stream.flush()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum StreamKind<S> {
|
||||
Indirect(IndirectStream<S>),
|
||||
Direct(DirectStream<S>),
|
||||
}
|
||||
|
||||
impl<S> StreamKind<S> {
|
||||
fn stream(&self) -> &S {
|
||||
match *self {
|
||||
StreamKind::Indirect(ref s) => &s.stream
|
||||
StreamKind::Indirect(ref s) => &s.stream,
|
||||
StreamKind::Direct(ref s) => &s.stream,
|
||||
}
|
||||
}
|
||||
|
||||
fn mut_stream(&mut self) -> &mut S {
|
||||
match *self {
|
||||
StreamKind::Indirect(ref mut s) => &mut s.stream
|
||||
StreamKind::Indirect(ref mut s) => &mut s.stream,
|
||||
StreamKind::Direct(ref mut s) => &mut s.stream,
|
||||
}
|
||||
}
|
||||
|
||||
fn ssl(&self) -> &Ssl {
|
||||
match *self {
|
||||
StreamKind::Indirect(ref s) => &s.ssl
|
||||
StreamKind::Indirect(ref s) => &s.ssl,
|
||||
StreamKind::Direct(ref s) => &s.ssl,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -887,7 +1000,8 @@ impl SslStream<net::TcpStream> {
|
|||
/// Create a new independently owned handle to the underlying socket.
|
||||
pub fn try_clone(&self) -> io::Result<SslStream<net::TcpStream>> {
|
||||
let kind = match self.kind {
|
||||
StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone()))
|
||||
StreamKind::Indirect(ref s) => StreamKind::Indirect(try!(s.try_clone())),
|
||||
StreamKind::Direct(ref s) => StreamKind::Direct(try!(s.try_clone()))
|
||||
};
|
||||
Ok(SslStream {
|
||||
kind: kind
|
||||
|
|
@ -901,6 +1015,46 @@ impl<S> fmt::Debug for SslStream<S> where S: fmt::Debug {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl<S: ::std::os::unix::io::AsRawFd> SslStream<S> {
|
||||
pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
let ssl = try!(ssl.into_ssl());
|
||||
let fd = stream.as_raw_fd() as c_int;
|
||||
let stream = try!(DirectStream::new_client(ssl, stream, fd));
|
||||
Ok(SslStream {
|
||||
kind: StreamKind::Direct(stream)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
let ssl = try!(ssl.into_ssl());
|
||||
let fd = stream.as_raw_fd() as c_int;
|
||||
let stream = try!(DirectStream::new_server(ssl, stream, fd));
|
||||
Ok(SslStream {
|
||||
kind: StreamKind::Direct(stream)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
impl<S: ::std::os::windows::io::AsRawSocket> SslStream<S> {
|
||||
pub fn new_client_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
let fd = stream.as_raw_socket() as c_int;
|
||||
let stream = try!(DirectStream::new_client(ssl, stream, fd));
|
||||
Ok(SslStream {
|
||||
kind: StreamKind::Direct(stream)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_server_direct<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
let fd = stream.as_raw_socket() as c_int;
|
||||
let stream = try!(DirectStream::new_server(ssl, stream, fd));
|
||||
Ok(SslStream {
|
||||
kind: StreamKind::Direct(stream)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Read+Write> SslStream<S> {
|
||||
pub fn new_client<T: IntoSsl>(ssl: T, stream: S) -> Result<SslStream<S>, SslError> {
|
||||
let stream = try!(IndirectStream::new_client(ssl, stream));
|
||||
|
|
@ -994,7 +1148,8 @@ impl<S: Read+Write> SslStream<S> {
|
|||
impl<S: Read+Write> Read for SslStream<S> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self.kind {
|
||||
StreamKind::Indirect(ref mut s) => s.read(buf)
|
||||
StreamKind::Indirect(ref mut s) => s.read(buf),
|
||||
StreamKind::Direct(ref mut s) => s.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1002,13 +1157,15 @@ impl<S: Read+Write> Read for SslStream<S> {
|
|||
impl<S: Read+Write> Write for SslStream<S> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self.kind {
|
||||
StreamKind::Indirect(ref mut s) => s.write(buf)
|
||||
StreamKind::Indirect(ref mut s) => s.write(buf),
|
||||
StreamKind::Direct(ref mut s) => s.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self.kind {
|
||||
StreamKind::Indirect(ref mut s) => s.flush()
|
||||
StreamKind::Indirect(ref mut s) => s.flush(),
|
||||
StreamKind::Direct(ref mut s) => s.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -317,8 +317,17 @@ fn test_write() {
|
|||
stream.flush().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_direct() {
|
||||
let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
|
||||
let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
|
||||
stream.write_all("hello".as_bytes()).unwrap();
|
||||
stream.flush().unwrap();
|
||||
stream.write_all(" there".as_bytes()).unwrap();
|
||||
stream.flush().unwrap();
|
||||
}
|
||||
|
||||
run_test!(get_peer_certificate, |method, stream| {
|
||||
//let stream = TcpStream::connect("127.0.0.1:15418").unwrap();
|
||||
let stream = SslStream::new_client(&SslContext::new(method).unwrap(), stream).unwrap();
|
||||
let cert = stream.get_peer_certificate().unwrap();
|
||||
let fingerprint = cert.fingerprint(SHA256).unwrap();
|
||||
|
|
@ -349,6 +358,14 @@ fn test_read() {
|
|||
io::copy(&mut stream, &mut io::sink()).ok().expect("read error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_direct() {
|
||||
let tcp = TcpStream::connect("127.0.0.1:15418").unwrap();
|
||||
let mut stream = SslStream::new_client_direct(&SslContext::new(Sslv23).unwrap(), tcp).unwrap();
|
||||
stream.write_all("GET /\r\n\r\n".as_bytes()).unwrap();
|
||||
stream.flush().unwrap();
|
||||
io::copy(&mut stream, &mut io::sink()).ok().expect("read error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pending() {
|
||||
|
|
|
|||
Loading…
Reference in New Issue