Introduce AsyncStreamBridge
This encapsulates a bit better the unsafety of task context management to invoke async code from inside boring.
This commit is contained in:
parent
97e2a8bc30
commit
1c790f7277
|
|
@ -0,0 +1,86 @@
|
||||||
|
//! Bridge between sync IO traits and async tokio IO traits.
|
||||||
|
|
||||||
|
use std::fmt;
|
||||||
|
use std::io;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll, Waker};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
|
||||||
|
pub(crate) struct AsyncStreamBridge<S> {
|
||||||
|
pub(crate) stream: S,
|
||||||
|
waker: Option<Waker>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> AsyncStreamBridge<S> {
|
||||||
|
pub(crate) fn new(stream: S) -> Self
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
Self {
|
||||||
|
stream,
|
||||||
|
waker: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn set_waker(&mut self, ctx: Option<&mut Context<'_>>) {
|
||||||
|
self.waker = ctx.map(|ctx| ctx.waker().clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the bridge has no waker.
|
||||||
|
pub(crate) fn with_context<F, R>(&mut self, f: F) -> R
|
||||||
|
where
|
||||||
|
S: Unpin,
|
||||||
|
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
|
||||||
|
{
|
||||||
|
let mut ctx =
|
||||||
|
Context::from_waker(self.waker.as_ref().expect("missing task context pointer"));
|
||||||
|
|
||||||
|
f(&mut ctx, Pin::new(&mut self.stream))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> io::Read for AsyncStreamBridge<S>
|
||||||
|
where
|
||||||
|
S: AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||||
|
self.with_context(|ctx, stream| {
|
||||||
|
let mut buf = ReadBuf::new(buf);
|
||||||
|
|
||||||
|
match stream.poll_read(ctx, &mut buf)? {
|
||||||
|
Poll::Ready(()) => Ok(buf.filled().len()),
|
||||||
|
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> io::Write for AsyncStreamBridge<S>
|
||||||
|
where
|
||||||
|
S: AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||||
|
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
|
||||||
|
Poll::Ready(r) => r,
|
||||||
|
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(&mut self) -> io::Result<()> {
|
||||||
|
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
|
||||||
|
Poll::Ready(r) => r,
|
||||||
|
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> fmt::Debug for AsyncStreamBridge<S>
|
||||||
|
where
|
||||||
|
S: fmt::Debug,
|
||||||
|
{
|
||||||
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
fmt::Debug::fmt(&self.stream, fmt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -27,6 +27,10 @@ use std::pin::Pin;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
|
||||||
|
mod bridge;
|
||||||
|
|
||||||
|
use self::bridge::AsyncStreamBridge;
|
||||||
|
|
||||||
/// Asynchronously performs a client-side TLS handshake over the provided stream.
|
/// Asynchronously performs a client-side TLS handshake over the provided stream.
|
||||||
pub async fn connect<S>(
|
pub async fn connect<S>(
|
||||||
config: ConnectConfiguration,
|
config: ConnectConfiguration,
|
||||||
|
|
@ -48,94 +52,18 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handshake<S>(
|
async fn handshake<S>(
|
||||||
f: impl FnOnce(StreamWrapper<S>) -> Result<MidHandshakeSslStream<StreamWrapper<S>>, ErrorStack>,
|
f: impl FnOnce(
|
||||||
|
AsyncStreamBridge<S>,
|
||||||
|
) -> Result<MidHandshakeSslStream<AsyncStreamBridge<S>>, ErrorStack>,
|
||||||
stream: S,
|
stream: S,
|
||||||
) -> Result<SslStream<S>, HandshakeError<S>>
|
) -> Result<SslStream<S>, HandshakeError<S>>
|
||||||
where
|
where
|
||||||
S: AsyncRead + AsyncWrite + Unpin,
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
{
|
{
|
||||||
let ongoing_handshake = Some(
|
let mid_handshake = f(AsyncStreamBridge::new(stream))
|
||||||
f(StreamWrapper { stream, context: 0 })
|
.map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
|
||||||
.map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?,
|
|
||||||
);
|
|
||||||
|
|
||||||
HandshakeFuture(ongoing_handshake).await
|
HandshakeFuture(Some(mid_handshake)).await
|
||||||
}
|
|
||||||
|
|
||||||
struct StreamWrapper<S> {
|
|
||||||
stream: S,
|
|
||||||
context: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> StreamWrapper<S> {
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// Must be called with `context` set to a valid pointer to a live `Context` object, and the
|
|
||||||
/// wrapper must be pinned in memory.
|
|
||||||
unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
|
|
||||||
debug_assert_ne!(self.context, 0);
|
|
||||||
let stream = Pin::new_unchecked(&mut self.stream);
|
|
||||||
let context = &mut *(self.context as *mut _);
|
|
||||||
(stream, context)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> fmt::Debug for StreamWrapper<S>
|
|
||||||
where
|
|
||||||
S: fmt::Debug,
|
|
||||||
{
|
|
||||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
fmt::Debug::fmt(&self.stream, fmt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> StreamWrapper<S>
|
|
||||||
where
|
|
||||||
S: Unpin,
|
|
||||||
{
|
|
||||||
fn with_context<F, R>(&mut self, f: F) -> R
|
|
||||||
where
|
|
||||||
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
|
|
||||||
{
|
|
||||||
unsafe {
|
|
||||||
assert_ne!(self.context, 0);
|
|
||||||
let waker = &mut *(self.context as *mut _);
|
|
||||||
f(waker, Pin::new(&mut self.stream))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> Read for StreamWrapper<S>
|
|
||||||
where
|
|
||||||
S: AsyncRead + Unpin,
|
|
||||||
{
|
|
||||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
|
||||||
let (stream, cx) = unsafe { self.parts() };
|
|
||||||
let mut buf = ReadBuf::new(buf);
|
|
||||||
match stream.poll_read(cx, &mut buf)? {
|
|
||||||
Poll::Ready(()) => Ok(buf.filled().len()),
|
|
||||||
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> Write for StreamWrapper<S>
|
|
||||||
where
|
|
||||||
S: AsyncWrite + Unpin,
|
|
||||||
{
|
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
|
||||||
match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) {
|
|
||||||
Poll::Ready(r) => r,
|
|
||||||
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
|
||||||
match self.with_context(|ctx, stream| stream.poll_flush(ctx)) {
|
|
||||||
Poll::Ready(r) => r,
|
|
||||||
Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
|
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
|
||||||
|
|
@ -154,7 +82,7 @@ fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
|
||||||
/// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written
|
/// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written
|
||||||
/// to a `SslStream` are encrypted when passing through to `S`.
|
/// to a `SslStream` are encrypted when passing through to `S`.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
|
pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
|
||||||
|
|
||||||
impl<S> SslStream<S> {
|
impl<S> SslStream<S> {
|
||||||
/// Returns a shared reference to the `Ssl` object associated with this stream.
|
/// Returns a shared reference to the `Ssl` object associated with this stream.
|
||||||
|
|
@ -172,14 +100,21 @@ impl<S> SslStream<S> {
|
||||||
&mut self.0.get_mut().stream
|
&mut self.0.get_mut().stream
|
||||||
}
|
}
|
||||||
|
|
||||||
fn with_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
|
fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
|
||||||
where
|
where
|
||||||
F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
|
F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
|
||||||
{
|
{
|
||||||
self.0.get_mut().context = ctx as *mut _ as usize;
|
self.0.get_mut().set_waker(Some(ctx));
|
||||||
let r = f(&mut self.0);
|
|
||||||
self.0.get_mut().context = 0;
|
let result = f(&mut self.0);
|
||||||
r
|
|
||||||
|
// NOTE(nox): This should also be executed when `f` panics,
|
||||||
|
// but it's not that important as boring segfaults on panics
|
||||||
|
// and we always set the context prior to doing anything with
|
||||||
|
// the inner async stream.
|
||||||
|
self.0.get_mut().set_waker(None);
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -195,8 +130,10 @@ where
|
||||||
///
|
///
|
||||||
/// The caller must ensure the pointer is valid.
|
/// The caller must ensure the pointer is valid.
|
||||||
pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self {
|
pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self {
|
||||||
let stream = StreamWrapper { stream, context: 0 };
|
Self(ssl::SslStream::from_raw_parts(
|
||||||
SslStream(ssl::SslStream::from_raw_parts(ssl, stream))
|
ssl,
|
||||||
|
AsyncStreamBridge::new(stream),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -209,7 +146,7 @@ where
|
||||||
ctx: &mut Context<'_>,
|
ctx: &mut Context<'_>,
|
||||||
buf: &mut ReadBuf,
|
buf: &mut ReadBuf,
|
||||||
) -> Poll<io::Result<()>> {
|
) -> Poll<io::Result<()>> {
|
||||||
self.with_context(ctx, |s| {
|
self.run_in_context(ctx, |s| {
|
||||||
// This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though
|
// This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though
|
||||||
// OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now.
|
// OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now.
|
||||||
let slice = unsafe {
|
let slice = unsafe {
|
||||||
|
|
@ -239,15 +176,15 @@ where
|
||||||
ctx: &mut Context,
|
ctx: &mut Context,
|
||||||
buf: &[u8],
|
buf: &[u8],
|
||||||
) -> Poll<io::Result<usize>> {
|
) -> Poll<io::Result<usize>> {
|
||||||
self.with_context(ctx, |s| cvt(s.write(buf)))
|
self.run_in_context(ctx, |s| cvt(s.write(buf)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
|
fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
self.with_context(ctx, |s| cvt(s.flush()))
|
self.run_in_context(ctx, |s| cvt(s.flush()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
|
fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
match self.with_context(ctx, |s| s.shutdown()) {
|
match self.run_in_context(ctx, |s| s.shutdown()) {
|
||||||
Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
|
Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
|
||||||
Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
|
Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
|
||||||
Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
|
Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
|
||||||
|
|
@ -265,7 +202,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The error type returned after a failed handshake.
|
/// The error type returned after a failed handshake.
|
||||||
pub struct HandshakeError<S>(ssl::HandshakeError<StreamWrapper<S>>);
|
pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
|
||||||
|
|
||||||
impl<S> HandshakeError<S> {
|
impl<S> HandshakeError<S> {
|
||||||
/// Returns a shared reference to the `Ssl` object associated with this error.
|
/// Returns a shared reference to the `Ssl` object associated with this error.
|
||||||
|
|
@ -333,7 +270,10 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct HandshakeFuture<S>(Option<MidHandshakeSslStream<StreamWrapper<S>>>);
|
/// Future for an ongoing TLS handshake.
|
||||||
|
///
|
||||||
|
/// See [`connect`] and [`accept`].
|
||||||
|
pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
|
||||||
|
|
||||||
impl<S> Future for HandshakeFuture<S>
|
impl<S> Future for HandshakeFuture<S>
|
||||||
where
|
where
|
||||||
|
|
@ -341,25 +281,34 @@ where
|
||||||
{
|
{
|
||||||
type Output = Result<SslStream<S>, HandshakeError<S>>;
|
type Output = Result<SslStream<S>, HandshakeError<S>>;
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
let mut s = self.0.take().expect("future polled after completion");
|
let mut mid_handshake = self.0.take().expect("future polled after completion");
|
||||||
|
|
||||||
s.get_mut().context = cx as *mut _ as usize;
|
mid_handshake.get_mut().set_waker(Some(ctx));
|
||||||
|
|
||||||
match s.handshake() {
|
match mid_handshake.handshake() {
|
||||||
Ok(mut s) => {
|
Ok(mut stream) => {
|
||||||
s.get_mut().context = 0;
|
stream.get_mut().set_waker(None);
|
||||||
|
|
||||||
Poll::Ready(Ok(SslStream(s)))
|
Poll::Ready(Ok(SslStream(stream)))
|
||||||
}
|
}
|
||||||
Err(ssl::HandshakeError::WouldBlock(mut s)) => {
|
Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
|
||||||
s.get_mut().context = 0;
|
mid_handshake.get_mut().set_waker(None);
|
||||||
|
|
||||||
self.0 = Some(s);
|
self.0 = Some(mid_handshake);
|
||||||
|
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
Err(e) => Poll::Ready(Err(HandshakeError(e))),
|
Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
|
||||||
|
mid_handshake.get_mut().set_waker(None);
|
||||||
|
|
||||||
|
Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
|
||||||
|
mid_handshake,
|
||||||
|
))))
|
||||||
|
}
|
||||||
|
Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
|
||||||
|
Poll::Ready(Err(HandshakeError(err)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue