Introduce AsyncStreamBridge
This encapsulates a bit better the unsafety of task context management to invoke async code from inside boring.
This commit is contained in:
parent
97e2a8bc30
commit
1c790f7277
|
|
@ -0,0 +1,86 @@
|
|||
//! Bridge between sync IO traits and async tokio IO traits.
|
||||
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
pub(crate) struct AsyncStreamBridge<S> {
|
||||
pub(crate) stream: S,
|
||||
waker: Option<Waker>,
|
||||
}
|
||||
|
||||
impl<S> AsyncStreamBridge<S> {
|
||||
pub(crate) fn new(stream: S) -> Self
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
Self {
|
||||
stream,
|
||||
waker: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_waker(&mut self, ctx: Option<&mut Context<'_>>) {
|
||||
self.waker = ctx.map(|ctx| ctx.waker().clone())
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the bridge has no waker.
|
||||
pub(crate) fn with_context<F, R>(&mut self, f: F) -> R
|
||||
where
|
||||
S: Unpin,
|
||||
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
|
||||
{
|
||||
let mut ctx =
|
||||
Context::from_waker(self.waker.as_ref().expect("missing task context pointer"));
|
||||
|
||||
f(&mut ctx, Pin::new(&mut self.stream))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> io::Read for AsyncStreamBridge<S>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.with_context(|ctx, stream| {
|
||||
let mut buf = ReadBuf::new(buf);
|
||||
|
||||
match stream.poll_read(ctx, &mut buf)? {
|
||||
Poll::Ready(()) => Ok(buf.filled().len()),
|
||||
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> io::Write for AsyncStreamBridge<S>
|
||||
where
|
||||
S: AsyncWrite + Unpin,
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
|
||||
Poll::Ready(r) => r,
|
||||
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
|
||||
Poll::Ready(r) => r,
|
||||
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> fmt::Debug for AsyncStreamBridge<S>
|
||||
where
|
||||
S: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt::Debug::fmt(&self.stream, fmt)
|
||||
}
|
||||
}
|
||||
|
|
@ -27,6 +27,10 @@ use std::pin::Pin;
|
|||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
mod bridge;
|
||||
|
||||
use self::bridge::AsyncStreamBridge;
|
||||
|
||||
/// Asynchronously performs a client-side TLS handshake over the provided stream.
|
||||
pub async fn connect<S>(
|
||||
config: ConnectConfiguration,
|
||||
|
|
@ -48,94 +52,18 @@ where
|
|||
}
|
||||
|
||||
async fn handshake<S>(
|
||||
f: impl FnOnce(StreamWrapper<S>) -> Result<MidHandshakeSslStream<StreamWrapper<S>>, ErrorStack>,
|
||||
f: impl FnOnce(
|
||||
AsyncStreamBridge<S>,
|
||||
) -> Result<MidHandshakeSslStream<AsyncStreamBridge<S>>, ErrorStack>,
|
||||
stream: S,
|
||||
) -> Result<SslStream<S>, HandshakeError<S>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let ongoing_handshake = Some(
|
||||
f(StreamWrapper { stream, context: 0 })
|
||||
.map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?,
|
||||
);
|
||||
let mid_handshake = f(AsyncStreamBridge::new(stream))
|
||||
.map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
|
||||
|
||||
HandshakeFuture(ongoing_handshake).await
|
||||
}
|
||||
|
||||
struct StreamWrapper<S> {
|
||||
stream: S,
|
||||
context: usize,
|
||||
}
|
||||
|
||||
impl<S> StreamWrapper<S> {
|
||||
/// # Safety
|
||||
///
|
||||
/// Must be called with `context` set to a valid pointer to a live `Context` object, and the
|
||||
/// wrapper must be pinned in memory.
|
||||
unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
|
||||
debug_assert_ne!(self.context, 0);
|
||||
let stream = Pin::new_unchecked(&mut self.stream);
|
||||
let context = &mut *(self.context as *mut _);
|
||||
(stream, context)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> fmt::Debug for StreamWrapper<S>
|
||||
where
|
||||
S: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt::Debug::fmt(&self.stream, fmt)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> StreamWrapper<S>
|
||||
where
|
||||
S: Unpin,
|
||||
{
|
||||
fn with_context<F, R>(&mut self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
|
||||
{
|
||||
unsafe {
|
||||
assert_ne!(self.context, 0);
|
||||
let waker = &mut *(self.context as *mut _);
|
||||
f(waker, Pin::new(&mut self.stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Read for StreamWrapper<S>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let (stream, cx) = unsafe { self.parts() };
|
||||
let mut buf = ReadBuf::new(buf);
|
||||
match stream.poll_read(cx, &mut buf)? {
|
||||
Poll::Ready(()) => Ok(buf.filled().len()),
|
||||
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Write for StreamWrapper<S>
|
||||
where
|
||||
S: AsyncWrite + Unpin,
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
|
||||
Poll::Ready(r) => r,
|
||||
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
|
||||
Poll::Ready(r) => r,
|
||||
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||
}
|
||||
}
|
||||
HandshakeFuture(Some(mid_handshake)).await
|
||||
}
|
||||
|
||||
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
|
||||
|
|
@ -154,7 +82,7 @@ fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
|
|||
/// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written
|
||||
/// to a `SslStream` are encrypted when passing through to `S`.
|
||||
#[derive(Debug)]
|
||||
pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
|
||||
pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
|
||||
|
||||
impl<S> SslStream<S> {
|
||||
/// Returns a shared reference to the `Ssl` object associated with this stream.
|
||||
|
|
@ -172,14 +100,21 @@ impl<S> SslStream<S> {
|
|||
&mut self.0.get_mut().stream
|
||||
}
|
||||
|
||||
fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
|
||||
fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
|
||||
F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
|
||||
{
|
||||
self.0.get_mut().context = ctx as *mut _ as usize;
|
||||
let r = f(&mut self.0);
|
||||
self.0.get_mut().context = 0;
|
||||
r
|
||||
self.0.get_mut().set_waker(Some(ctx));
|
||||
|
||||
let result = f(&mut self.0);
|
||||
|
||||
// NOTE(nox): This should also be executed when `f` panics,
|
||||
// but it's not that important as boring segfaults on panics
|
||||
// and we always set the context prior to doing anything with
|
||||
// the inner async stream.
|
||||
self.0.get_mut().set_waker(None);
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -195,8 +130,10 @@ where
|
|||
///
|
||||
/// The caller must ensure the pointer is valid.
|
||||
pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self {
|
||||
let stream = StreamWrapper { stream, context: 0 };
|
||||
SslStream(ssl::SslStream::from_raw_parts(ssl, stream))
|
||||
Self(ssl::SslStream::from_raw_parts(
|
||||
ssl,
|
||||
AsyncStreamBridge::new(stream),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -209,7 +146,7 @@ where
|
|||
ctx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.with_context(ctx, |s| {
|
||||
self.run_in_context(ctx, |s| {
|
||||
// This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though
|
||||
// OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now.
|
||||
let slice = unsafe {
|
||||
|
|
@ -239,15 +176,15 @@ where
|
|||
ctx: &mut Context,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.with_context(ctx, |s| cvt(s.write(buf)))
|
||||
self.run_in_context(ctx, |s| cvt(s.write(buf)))
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
|
||||
self.with_context(ctx, |s| cvt(s.flush()))
|
||||
self.run_in_context(ctx, |s| cvt(s.flush()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
|
||||
match self.with_context(ctx, |s| s.shutdown()) {
|
||||
match self.run_in_context(ctx, |s| s.shutdown()) {
|
||||
Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
|
||||
Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
|
||||
Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
|
||||
|
|
@ -265,7 +202,7 @@ where
|
|||
}
|
||||
|
||||
/// The error type returned after a failed handshake.
|
||||
pub struct HandshakeError<S>(ssl::HandshakeError<StreamWrapper<S>>);
|
||||
pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
|
||||
|
||||
impl<S> HandshakeError<S> {
|
||||
/// Returns a shared reference to the `Ssl` object associated with this error.
|
||||
|
|
@ -333,7 +270,10 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
struct HandshakeFuture<S>(Option<MidHandshakeSslStream<StreamWrapper<S>>>);
|
||||
/// Future for an ongoing TLS handshake.
|
||||
///
|
||||
/// See [`connect`] and [`accept`].
|
||||
pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
|
||||
|
||||
impl<S> Future for HandshakeFuture<S>
|
||||
where
|
||||
|
|
@ -341,25 +281,34 @@ where
|
|||
{
|
||||
type Output = Result<SslStream<S>, HandshakeError<S>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut s = self.0.take().expect("future polled after completion");
|
||||
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut mid_handshake = self.0.take().expect("future polled after completion");
|
||||
|
||||
s.get_mut().context = cx as *mut _ as usize;
|
||||
mid_handshake.get_mut().set_waker(Some(ctx));
|
||||
|
||||
match s.handshake() {
|
||||
Ok(mut s) => {
|
||||
s.get_mut().context = 0;
|
||||
match mid_handshake.handshake() {
|
||||
Ok(mut stream) => {
|
||||
stream.get_mut().set_waker(None);
|
||||
|
||||
Poll::Ready(Ok(SslStream(s)))
|
||||
Poll::Ready(Ok(SslStream(stream)))
|
||||
}
|
||||
Err(ssl::HandshakeError::WouldBlock(mut s)) => {
|
||||
s.get_mut().context = 0;
|
||||
Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
|
||||
mid_handshake.get_mut().set_waker(None);
|
||||
|
||||
self.0 = Some(s);
|
||||
self.0 = Some(mid_handshake);
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
Err(e) => Poll::Ready(Err(HandshakeError(e))),
|
||||
Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
|
||||
mid_handshake.get_mut().set_waker(None);
|
||||
|
||||
Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
|
||||
mid_handshake,
|
||||
))))
|
||||
}
|
||||
Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
|
||||
Poll::Ready(Err(HandshakeError(err)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue