diff --git a/hyper-boring/Cargo.toml b/hyper-boring/Cargo.toml index dd704209..b58d1318 100644 --- a/hyper-boring/Cargo.toml +++ b/hyper-boring/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hyper-boring" -version = "1.0.4" +version = "2.0.0" authors = ["Steven Fackler ", "Ivan Nikulin "] edition = "2018" description = "Hyper TLS support via BoringSSL" @@ -19,16 +19,16 @@ runtime = ["hyper/runtime"] antidote = "1.0.0" bytes = "0.5" http = "0.2" -hyper = { version = "0.13", default-features = false } +hyper = { version = "0.14", default-features = false, features = ["client"] } linked_hash_set = "0.1" once_cell = "1.0" boring = { version = "1.0.2", path = "../boring" } boring-sys = { version = "1.0.2", path = "../boring-sys" } -tokio = "0.2" -tokio-boring = { version = "1.0.2", path = "../tokio-boring" } +tokio = "1" +tokio-boring = { version = "2", path = "../tokio-boring" } tower-layer = "0.3" [dev-dependencies] -hyper = "0.13" -tokio = { version = "0.2", features = ["full"] } +hyper = { version = "0.14", features = ["full"] } +tokio = { version = "1", features = ["full"] } futures = "0.3" diff --git a/hyper-boring/src/lib.rs b/hyper-boring/src/lib.rs index 155dc4f2..2819192f 100644 --- a/hyper-boring/src/lib.rs +++ b/hyper-boring/src/lib.rs @@ -8,7 +8,6 @@ use boring::ex_data::Index; use boring::ssl::{ ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, SslSessionCacheMode, }; -use bytes::{Buf, BufMut}; use http::uri::Scheme; use hyper::client::connect::{Connected, Connection}; #[cfg(feature = "runtime")] @@ -20,12 +19,11 @@ use std::error::Error; use std::fmt::Debug; use std::future::Future; use std::io; -use std::mem::MaybeUninit; use std::net; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_boring::SslStream; use tower_layer::Layer; @@ -267,37 +265,16 @@ impl AsyncRead for MaybeHttpsStream where T: AsyncRead + AsyncWrite + Unpin, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - match &*self { - MaybeHttpsStream::Http(s) => s.prepare_uninitialized_buffer(buf), - MaybeHttpsStream::Https(s) => s.prepare_uninitialized_buffer(buf), - } - } - fn poll_read( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { match &mut *self { MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(ctx, buf), MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(ctx, buf), } } - - fn poll_read_buf( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - buf: &mut B, - ) -> Poll> - where - B: BufMut, - { - match &mut *self { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_read_buf(ctx, buf), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_read_buf(ctx, buf), - } - } } impl AsyncWrite for MaybeHttpsStream @@ -328,20 +305,6 @@ where MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(ctx), } } - - fn poll_write_buf( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - buf: &mut B, - ) -> Poll> - where - B: Buf, - { - match &mut *self { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_write_buf(ctx, buf), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_write_buf(ctx, buf), - } - } } impl Connection for MaybeHttpsStream diff --git a/hyper-boring/src/test.rs b/hyper-boring/src/test.rs index 8d431cbe..39b96e34 100644 --- a/hyper-boring/src/test.rs +++ b/hyper-boring/src/test.rs @@ -28,7 +28,7 @@ async fn google() { #[tokio::test] async fn localhost() { - let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); let server = async move { @@ -89,7 +89,7 @@ async fn localhost() { async fn alpn_h2() { use boring::ssl::{self, AlpnError}; - let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); let server = async move { diff --git a/tokio-boring/Cargo.toml b/tokio-boring/Cargo.toml index dba07eb5..669051ef 100644 --- a/tokio-boring/Cargo.toml +++ b/tokio-boring/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-boring" -version = "1.0.3" +version = "2.0.0" authors = ["Alex Crichton ", "Ivan Nikulin "] license = "MIT/Apache-2.0" edition = "2018" @@ -14,8 +14,8 @@ An implementation of SSL streams for Tokio backed by BoringSSL [dependencies] boring = { version = "1.0.3", path = "../boring" } boring-sys = { version = "1.0.2", path = "../boring-sys" } -tokio = "0.2" +tokio = "1" [dev-dependencies] futures = "0.3" -tokio = { version = "0.2", features = ["full"] } +tokio = { version = "1", features = ["full"] } diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index dbda8be1..750e1ee1 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -22,10 +22,9 @@ use std::error::Error; use std::fmt; use std::future::Future; use std::io::{self, Read, Write}; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; /// Asynchronously performs a client-side TLS handshake over the provided stream. pub async fn connect( @@ -69,6 +68,19 @@ struct StreamWrapper { context: usize, } +impl StreamWrapper { + /// # Safety + /// + /// Must be called with `context` set to a valid pointer to a live `Context` object, and the + /// wrapper must be pinned in memory. + unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) { + debug_assert_ne!(self.context, 0); + let stream = Pin::new_unchecked(&mut self.stream); + let context = &mut *(self.context as *mut _); + (stream, context) + } +} + impl fmt::Debug for StreamWrapper where S: fmt::Debug, @@ -99,8 +111,10 @@ where S: AsyncRead + Unpin, { fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self.with_context(|ctx, stream| stream.poll_read(ctx, buf)) { - Poll::Ready(r) => r, + let (stream, cx) = unsafe { self.parts() }; + let mut buf = ReadBuf::new(buf); + match stream.poll_read(cx, &mut buf)? { + Poll::Ready(()) => Ok(buf.filled().len()), Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), } } @@ -191,19 +205,29 @@ impl AsyncRead for SslStream where S: AsyncRead + AsyncWrite + Unpin, { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { - // Note that this does not forward to `S` because the buffer is - // unconditionally filled in by OpenSSL, not the actual object `S`. - // We're decrypting bytes from `S` into the buffer above! - false - } - fn poll_read( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - self.with_context(ctx, |s| cvt(s.read(buf))) + buf: &mut ReadBuf, + ) -> Poll> { + self.with_context(ctx, |s| { + // This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though + // OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now. + let slice = unsafe { + let buf = buf.unfilled_mut(); + std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast::(), buf.len()) + }; + match cvt(s.read(slice))? { + Poll::Ready(nread) => { + unsafe { + buf.assume_init(nread); + } + buf.advance(nread); + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + } + }) } } diff --git a/tokio-boring/tests/google.rs b/tokio-boring/tests/google.rs index d408d0ec..72c5a040 100644 --- a/tokio-boring/tests/google.rs +++ b/tokio-boring/tests/google.rs @@ -38,7 +38,10 @@ fn create_server() -> ( SocketAddr, ) { let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let mut listener = TcpListener::from_std(listener).unwrap(); + + listener.set_nonblocking(true).unwrap(); + + let listener = TcpListener::from_std(listener).unwrap(); let addr = listener.local_addr().unwrap(); let server = async move {