diff --git a/tokio-boring/src/bridge.rs b/tokio-boring/src/bridge.rs new file mode 100644 index 00000000..9de3fc22 --- /dev/null +++ b/tokio-boring/src/bridge.rs @@ -0,0 +1,86 @@ +//! Bridge between sync IO traits and async tokio IO traits. + +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub(crate) struct AsyncStreamBridge { + pub(crate) stream: S, + waker: Option, +} + +impl AsyncStreamBridge { + pub(crate) fn new(stream: S) -> Self + where + S: AsyncRead + AsyncWrite + Unpin, + { + Self { + stream, + waker: None, + } + } + + pub(crate) fn set_waker(&mut self, ctx: Option<&mut Context<'_>>) { + self.waker = ctx.map(|ctx| ctx.waker().clone()) + } + + /// # Panics + /// + /// Panics if the bridge has no waker. + pub(crate) fn with_context(&mut self, f: F) -> R + where + S: Unpin, + F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, + { + let mut ctx = + Context::from_waker(self.waker.as_ref().expect("missing task context pointer")); + + f(&mut ctx, Pin::new(&mut self.stream)) + } +} + +impl io::Read for AsyncStreamBridge +where + S: AsyncRead + Unpin, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.with_context(|ctx, stream| { + let mut buf = ReadBuf::new(buf); + + match stream.poll_read(ctx, &mut buf)? { + Poll::Ready(()) => Ok(buf.filled().len()), + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + }) + } +} + +impl io::Write for AsyncStreamBridge +where + S: AsyncWrite + Unpin, +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) { + Poll::Ready(r) => r, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self.with_context(|ctx, stream| stream.poll_flush(ctx)) { + Poll::Ready(r) => r, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + } +} + +impl fmt::Debug for AsyncStreamBridge +where + S: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.stream, fmt) + } +} diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index f594231d..f437ee26 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -27,6 +27,10 @@ use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +mod bridge; + +use self::bridge::AsyncStreamBridge; + /// Asynchronously performs a client-side TLS handshake over the provided stream. pub async fn connect( config: ConnectConfiguration, @@ -48,94 +52,18 @@ where } async fn handshake( - f: impl FnOnce(StreamWrapper) -> Result>, ErrorStack>, + f: impl FnOnce( + AsyncStreamBridge, + ) -> Result>, ErrorStack>, stream: S, ) -> Result, HandshakeError> where S: AsyncRead + AsyncWrite + Unpin, { - let ongoing_handshake = Some( - f(StreamWrapper { stream, context: 0 }) - .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?, - ); + let mid_handshake = f(AsyncStreamBridge::new(stream)) + .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?; - HandshakeFuture(ongoing_handshake).await -} - -struct StreamWrapper { - stream: S, - context: usize, -} - -impl StreamWrapper { - /// # Safety - /// - /// Must be called with `context` set to a valid pointer to a live `Context` object, and the - /// wrapper must be pinned in memory. - unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) { - debug_assert_ne!(self.context, 0); - let stream = Pin::new_unchecked(&mut self.stream); - let context = &mut *(self.context as *mut _); - (stream, context) - } -} - -impl fmt::Debug for StreamWrapper -where - S: fmt::Debug, -{ - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Debug::fmt(&self.stream, fmt) - } -} - -impl StreamWrapper -where - S: Unpin, -{ - fn with_context(&mut self, f: F) -> R - where - F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, - { - unsafe { - assert_ne!(self.context, 0); - let waker = &mut *(self.context as *mut _); - f(waker, Pin::new(&mut self.stream)) - } - } -} - -impl Read for StreamWrapper -where - S: AsyncRead + Unpin, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let (stream, cx) = unsafe { self.parts() }; - let mut buf = ReadBuf::new(buf); - match stream.poll_read(cx, &mut buf)? { - Poll::Ready(()) => Ok(buf.filled().len()), - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } -} - -impl Write for StreamWrapper -where - S: AsyncWrite + Unpin, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self.with_context(|ctx, stream| stream.poll_flush(ctx)) { - Poll::Ready(r) => r, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), - } - } + HandshakeFuture(Some(mid_handshake)).await } fn cvt(r: io::Result) -> Poll> { @@ -154,7 +82,7 @@ fn cvt(r: io::Result) -> Poll> { /// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written /// to a `SslStream` are encrypted when passing through to `S`. #[derive(Debug)] -pub struct SslStream(ssl::SslStream>); +pub struct SslStream(ssl::SslStream>); impl SslStream { /// Returns a shared reference to the `Ssl` object associated with this stream. @@ -172,14 +100,21 @@ impl SslStream { &mut self.0.get_mut().stream } - fn with_context(&mut self, ctx: &mut Context<'_>, f: F) -> R + fn run_in_context(&mut self, ctx: &mut Context<'_>, f: F) -> R where - F: FnOnce(&mut ssl::SslStream>) -> R, + F: FnOnce(&mut ssl::SslStream>) -> R, { - self.0.get_mut().context = ctx as *mut _ as usize; - let r = f(&mut self.0); - self.0.get_mut().context = 0; - r + self.0.get_mut().set_waker(Some(ctx)); + + let result = f(&mut self.0); + + // NOTE(nox): This should also be executed when `f` panics, + // but it's not that important as boring segfaults on panics + // and we always set the context prior to doing anything with + // the inner async stream. + self.0.get_mut().set_waker(None); + + result } } @@ -195,8 +130,10 @@ where /// /// The caller must ensure the pointer is valid. pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self { - let stream = StreamWrapper { stream, context: 0 }; - SslStream(ssl::SslStream::from_raw_parts(ssl, stream)) + Self(ssl::SslStream::from_raw_parts( + ssl, + AsyncStreamBridge::new(stream), + )) } } @@ -209,7 +146,7 @@ where ctx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll> { - self.with_context(ctx, |s| { + self.run_in_context(ctx, |s| { // This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though // OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now. let slice = unsafe { @@ -239,15 +176,15 @@ where ctx: &mut Context, buf: &[u8], ) -> Poll> { - self.with_context(ctx, |s| cvt(s.write(buf))) + self.run_in_context(ctx, |s| cvt(s.write(buf))) } fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { - self.with_context(ctx, |s| cvt(s.flush())) + self.run_in_context(ctx, |s| cvt(s.flush())) } fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { - match self.with_context(ctx, |s| s.shutdown()) { + match self.run_in_context(ctx, |s| s.shutdown()) { Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {} Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {} Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => { @@ -265,7 +202,7 @@ where } /// The error type returned after a failed handshake. -pub struct HandshakeError(ssl::HandshakeError>); +pub struct HandshakeError(ssl::HandshakeError>); impl HandshakeError { /// Returns a shared reference to the `Ssl` object associated with this error. @@ -333,7 +270,10 @@ where } } -struct HandshakeFuture(Option>>); +/// Future for an ongoing TLS handshake. +/// +/// See [`connect`] and [`accept`]. +pub struct HandshakeFuture(Option>>); impl Future for HandshakeFuture where @@ -341,25 +281,34 @@ where { type Output = Result, HandshakeError>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut s = self.0.take().expect("future polled after completion"); + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + let mut mid_handshake = self.0.take().expect("future polled after completion"); - s.get_mut().context = cx as *mut _ as usize; + mid_handshake.get_mut().set_waker(Some(ctx)); - match s.handshake() { - Ok(mut s) => { - s.get_mut().context = 0; + match mid_handshake.handshake() { + Ok(mut stream) => { + stream.get_mut().set_waker(None); - Poll::Ready(Ok(SslStream(s))) + Poll::Ready(Ok(SslStream(stream))) } - Err(ssl::HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = 0; + Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => { + mid_handshake.get_mut().set_waker(None); - self.0 = Some(s); + self.0 = Some(mid_handshake); Poll::Pending } - Err(e) => Poll::Ready(Err(HandshakeError(e))), + Err(ssl::HandshakeError::Failure(mut mid_handshake)) => { + mid_handshake.get_mut().set_waker(None); + + Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure( + mid_handshake, + )))) + } + Err(err @ ssl::HandshakeError::SetupFailure(_)) => { + Poll::Ready(Err(HandshakeError(err))) + } } } }