Introduce AsyncStreamBridge

This encapsulates a bit better the unsafety of task context
management to invoke async code from inside boring.
This commit is contained in:
Anthony Ramine 2023-08-04 13:09:24 +02:00 committed by Alessandro Ghedini
parent 97e2a8bc30
commit 1c790f7277
2 changed files with 143 additions and 108 deletions

View File

@ -0,0 +1,86 @@
//! Bridge between sync IO traits and async tokio IO traits.
use std::fmt;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pub(crate) struct AsyncStreamBridge<S> {
pub(crate) stream: S,
waker: Option<Waker>,
}
impl<S> AsyncStreamBridge<S> {
pub(crate) fn new(stream: S) -> Self
where
S: AsyncRead + AsyncWrite + Unpin,
{
Self {
stream,
waker: None,
}
}
pub(crate) fn set_waker(&mut self, ctx: Option<&mut Context<'_>>) {
self.waker = ctx.map(|ctx| ctx.waker().clone())
}
/// # Panics
///
/// Panics if the bridge has no waker.
pub(crate) fn with_context<F, R>(&mut self, f: F) -> R
where
S: Unpin,
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
{
let mut ctx =
Context::from_waker(self.waker.as_ref().expect("missing task context pointer"));
f(&mut ctx, Pin::new(&mut self.stream))
}
}
impl<S> io::Read for AsyncStreamBridge<S>
where
S: AsyncRead + Unpin,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.with_context(|ctx, stream| {
let mut buf = ReadBuf::new(buf);
match stream.poll_read(ctx, &mut buf)? {
Poll::Ready(()) => Ok(buf.filled().len()),
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
})
}
}
impl<S> io::Write for AsyncStreamBridge<S>
where
S: AsyncWrite + Unpin,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
fn flush(&mut self) -> io::Result<()> {
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}
impl<S> fmt::Debug for AsyncStreamBridge<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.stream, fmt)
}
}

View File

@ -27,6 +27,10 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
mod bridge;
use self::bridge::AsyncStreamBridge;
/// Asynchronously performs a client-side TLS handshake over the provided stream.
pub async fn connect<S>(
config: ConnectConfiguration,
@ -48,94 +52,18 @@ where
}
async fn handshake<S>(
f: impl FnOnce(StreamWrapper<S>) -> Result<MidHandshakeSslStream<StreamWrapper<S>>, ErrorStack>,
f: impl FnOnce(
AsyncStreamBridge<S>,
) -> Result<MidHandshakeSslStream<AsyncStreamBridge<S>>, ErrorStack>,
stream: S,
) -> Result<SslStream<S>, HandshakeError<S>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let ongoing_handshake = Some(
f(StreamWrapper { stream, context: 0 })
.map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?,
);
let mid_handshake = f(AsyncStreamBridge::new(stream))
.map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
HandshakeFuture(ongoing_handshake).await
}
struct StreamWrapper<S> {
stream: S,
context: usize,
}
impl<S> StreamWrapper<S> {
/// # Safety
///
/// Must be called with `context` set to a valid pointer to a live `Context` object, and the
/// wrapper must be pinned in memory.
unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
debug_assert_ne!(self.context, 0);
let stream = Pin::new_unchecked(&mut self.stream);
let context = &mut *(self.context as *mut _);
(stream, context)
}
}
impl<S> fmt::Debug for StreamWrapper<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.stream, fmt)
}
}
impl<S> StreamWrapper<S>
where
S: Unpin,
{
fn with_context<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
{
unsafe {
assert_ne!(self.context, 0);
let waker = &mut *(self.context as *mut _);
f(waker, Pin::new(&mut self.stream))
}
}
}
impl<S> Read for StreamWrapper<S>
where
S: AsyncRead + Unpin,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let (stream, cx) = unsafe { self.parts() };
let mut buf = ReadBuf::new(buf);
match stream.poll_read(cx, &mut buf)? {
Poll::Ready(()) => Ok(buf.filled().len()),
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
}
impl<S> Write for StreamWrapper<S>
where
S: AsyncWrite + Unpin,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
fn flush(&mut self) -> io::Result<()> {
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
Poll::Ready(r) => r,
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
}
}
HandshakeFuture(Some(mid_handshake)).await
}
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
@ -154,7 +82,7 @@ fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
/// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written
/// to a `SslStream` are encrypted when passing through to `S`.
#[derive(Debug)]
pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
impl<S> SslStream<S> {
/// Returns a shared reference to the `Ssl` object associated with this stream.
@ -172,14 +100,21 @@ impl<S> SslStream<S> {
&mut self.0.get_mut().stream
}
fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
where
F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
{
self.0.get_mut().context = ctx as *mut _ as usize;
let r = f(&mut self.0);
self.0.get_mut().context = 0;
r
self.0.get_mut().set_waker(Some(ctx));
let result = f(&mut self.0);
// NOTE(nox): This should also be executed when `f` panics,
// but it's not that important as boring segfaults on panics
// and we always set the context prior to doing anything with
// the inner async stream.
self.0.get_mut().set_waker(None);
result
}
}
@ -195,8 +130,10 @@ where
///
/// The caller must ensure the pointer is valid.
pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self {
let stream = StreamWrapper { stream, context: 0 };
SslStream(ssl::SslStream::from_raw_parts(ssl, stream))
Self(ssl::SslStream::from_raw_parts(
ssl,
AsyncStreamBridge::new(stream),
))
}
}
@ -209,7 +146,7 @@ where
ctx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
self.with_context(ctx, |s| {
self.run_in_context(ctx, |s| {
// This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though
// OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now.
let slice = unsafe {
@ -239,15 +176,15 @@ where
ctx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.with_context(ctx, |s| cvt(s.write(buf)))
self.run_in_context(ctx, |s| cvt(s.write(buf)))
}
fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
self.with_context(ctx, |s| cvt(s.flush()))
self.run_in_context(ctx, |s| cvt(s.flush()))
}
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
match self.with_context(ctx, |s| s.shutdown()) {
match self.run_in_context(ctx, |s| s.shutdown()) {
Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
@ -265,7 +202,7 @@ where
}
/// The error type returned after a failed handshake.
pub struct HandshakeError<S>(ssl::HandshakeError<StreamWrapper<S>>);
pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
impl<S> HandshakeError<S> {
/// Returns a shared reference to the `Ssl` object associated with this error.
@ -333,7 +270,10 @@ where
}
}
struct HandshakeFuture<S>(Option<MidHandshakeSslStream<StreamWrapper<S>>>);
/// Future for an ongoing TLS handshake.
///
/// See [`connect`] and [`accept`].
pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
impl<S> Future for HandshakeFuture<S>
where
@ -341,25 +281,34 @@ where
{
type Output = Result<SslStream<S>, HandshakeError<S>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut s = self.0.take().expect("future polled after completion");
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let mut mid_handshake = self.0.take().expect("future polled after completion");
s.get_mut().context = cx as *mut _ as usize;
mid_handshake.get_mut().set_waker(Some(ctx));
match s.handshake() {
Ok(mut s) => {
s.get_mut().context = 0;
match mid_handshake.handshake() {
Ok(mut stream) => {
stream.get_mut().set_waker(None);
Poll::Ready(Ok(SslStream(s)))
Poll::Ready(Ok(SslStream(stream)))
}
Err(ssl::HandshakeError::WouldBlock(mut s)) => {
s.get_mut().context = 0;
Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
mid_handshake.get_mut().set_waker(None);
self.0 = Some(s);
self.0 = Some(mid_handshake);
Poll::Pending
}
Err(e) => Poll::Ready(Err(HandshakeError(e))),
Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
mid_handshake.get_mut().set_waker(None);
Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
mid_handshake,
))))
}
Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
Poll::Ready(Err(HandshakeError(err)))
}
}
}
}