//! Async TLS streams backed by BoringSSL //! //! This library is an implementation of TLS streams using BoringSSL for //! negotiating the connection. Each TLS stream implements the `Read` and //! `Write` traits to interact and interoperate with the rest of the futures I/O //! ecosystem. Client connections initiated from this crate verify hostnames //! automatically and by default. //! //! This crate primarily exports this ability through two extension traits, //! `SslConnectorExt` and `SslAcceptorExt`. These traits augment the //! functionality provided by the [`boring` crate](https://github.com/cloudflare/boring) crate, //! on which this crate is built. Configuration of TLS parameters is still primarily done through //! the [`boring` crate](https://github.com/cloudflare/boring) #![warn(missing_docs)] use boring::ssl::{ self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor, SslRef, }; use boring_sys as ffi; use std::error::Error; use std::fmt; use std::future::Future; use std::io::{self, Read, Write}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; /// Asynchronously performs a client-side TLS handshake over the provided stream. pub async fn connect( config: ConnectConfiguration, domain: &str, stream: S, ) -> Result, HandshakeError> where S: AsyncRead + AsyncWrite + Unpin, { handshake(|s| config.connect(domain, s), stream).await } /// Asynchronously performs a server-side TLS handshake over the provided stream. pub async fn accept(acceptor: &SslAcceptor, stream: S) -> Result, HandshakeError> where S: AsyncRead + AsyncWrite + Unpin, { handshake(|s| acceptor.accept(s), stream).await } async fn handshake(f: F, stream: S) -> Result, HandshakeError> where F: FnOnce( StreamWrapper, ) -> Result>, ssl::HandshakeError>> + Unpin, S: AsyncRead + AsyncWrite + Unpin, { let start = StartHandshakeFuture(Some(StartHandshakeFutureInner { f, stream })); match start.await? { StartedHandshake::Done(s) => Ok(s), StartedHandshake::Mid(s) => HandshakeFuture(Some(s)).await, } } struct StreamWrapper { stream: S, context: usize, } impl StreamWrapper { /// # Safety /// /// Must be called with `context` set to a valid pointer to a live `Context` object, and the /// wrapper must be pinned in memory. unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) { debug_assert_ne!(self.context, 0); let stream = Pin::new_unchecked(&mut self.stream); let context = &mut *(self.context as *mut _); (stream, context) } } impl fmt::Debug for StreamWrapper where S: fmt::Debug, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(&self.stream, fmt) } } impl StreamWrapper where S: Unpin, { fn with_context(&mut self, f: F) -> R where F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R, { unsafe { assert_ne!(self.context, 0); let waker = &mut *(self.context as *mut _); f(waker, Pin::new(&mut self.stream)) } } } impl Read for StreamWrapper where S: AsyncRead + Unpin, { fn read(&mut self, buf: &mut [u8]) -> io::Result { let (stream, cx) = unsafe { self.parts() }; let mut buf = ReadBuf::new(buf); match stream.poll_read(cx, &mut buf)? { Poll::Ready(()) => Ok(buf.filled().len()), Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), } } } impl Write for StreamWrapper where S: AsyncWrite + Unpin, { fn write(&mut self, buf: &[u8]) -> io::Result { match self.with_context(|ctx, stream| stream.poll_write(ctx, buf)) { Poll::Ready(r) => r, Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), } } fn flush(&mut self) -> io::Result<()> { match self.with_context(|ctx, stream| stream.poll_flush(ctx)) { Poll::Ready(r) => r, Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), } } } fn cvt(r: io::Result) -> Poll> { match r { Ok(v) => Poll::Ready(Ok(v)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(e) => Poll::Ready(Err(e)), } } /// A wrapper around an underlying raw stream which implements the SSL /// protocol. /// /// A `SslStream` represents a handshake that has been completed successfully /// and both the server and the client are ready for receiving and sending /// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written /// to a `SslStream` are encrypted when passing through to `S`. #[derive(Debug)] pub struct SslStream(ssl::SslStream>); impl SslStream { /// Returns a shared reference to the `Ssl` object associated with this stream. pub fn ssl(&self) -> &SslRef { self.0.ssl() } /// Returns a shared reference to the underlying stream. pub fn get_ref(&self) -> &S { &self.0.get_ref().stream } /// Returns a mutable reference to the underlying stream. pub fn get_mut(&mut self) -> &mut S { &mut self.0.get_mut().stream } fn with_context(&mut self, ctx: &mut Context<'_>, f: F) -> R where F: FnOnce(&mut ssl::SslStream>) -> R, { self.0.get_mut().context = ctx as *mut _ as usize; let r = f(&mut self.0); self.0.get_mut().context = 0; r } } impl SslStream where S: AsyncRead + AsyncWrite + Unpin, { /// Constructs an `SslStream` from a pointer to the underlying OpenSSL `SSL` struct. /// /// This is useful if the handshake has already been completed elsewhere. /// /// # Safety /// /// The caller must ensure the pointer is valid. pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self { let stream = StreamWrapper { stream, context: 0 }; SslStream(ssl::SslStream::from_raw_parts(ssl, stream)) } } impl AsyncRead for SslStream where S: AsyncRead + AsyncWrite + Unpin, { fn poll_read( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll> { self.with_context(ctx, |s| { // This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though // OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now. let slice = unsafe { let buf = buf.unfilled_mut(); std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast::(), buf.len()) }; match cvt(s.read(slice))? { Poll::Ready(nread) => { unsafe { buf.assume_init(nread); } buf.advance(nread); Poll::Ready(Ok(())) } Poll::Pending => Poll::Pending, } }) } } impl AsyncWrite for SslStream where S: AsyncRead + AsyncWrite + Unpin, { fn poll_write( mut self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8], ) -> Poll> { self.with_context(ctx, |s| cvt(s.write(buf))) } fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { self.with_context(ctx, |s| cvt(s.flush())) } fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { match self.with_context(ctx, |s| s.shutdown()) { Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {} Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {} Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => { return Poll::Pending; } Err(e) => { return Poll::Ready(Err(e .into_io_error() .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))); } } Pin::new(&mut self.0.get_mut().stream).poll_shutdown(ctx) } } /// The error type returned after a failed handshake. pub struct HandshakeError(ssl::HandshakeError>); impl HandshakeError { /// Returns a shared reference to the `Ssl` object associated with this error. pub fn ssl(&self) -> Option<&SslRef> { match &self.0 { ssl::HandshakeError::Failure(s) => Some(s.ssl()), _ => None, } } /// Converts error to the source data stream that was used for the handshake. pub fn into_source_stream(self) -> Option { match self.0 { ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream), _ => None, } } /// Returns a reference to the source data stream. pub fn as_source_stream(&self) -> Option<&S> { match &self.0 { ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream), _ => None, } } } impl fmt::Debug for HandshakeError where S: fmt::Debug, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Debug::fmt(&self.0, fmt) } } impl fmt::Display for HandshakeError where S: fmt::Debug, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.0, fmt) } } impl Error for HandshakeError where S: fmt::Debug, { fn source(&self) -> Option<&(dyn Error + 'static)> { self.0.source() } } enum StartedHandshake { Done(SslStream), Mid(MidHandshakeSslStream>), } struct StartHandshakeFuture(Option>); struct StartHandshakeFutureInner { f: F, stream: S, } impl Future for StartHandshakeFuture where F: FnOnce( StreamWrapper, ) -> Result>, ssl::HandshakeError>> + Unpin, S: Unpin, { type Output = Result, HandshakeError>; fn poll( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, ) -> Poll, HandshakeError>> { let inner = self.0.take().expect("future polled after completion"); let stream = StreamWrapper { stream: inner.stream, context: ctx as *mut _ as usize, }; match (inner.f)(stream) { Ok(mut s) => { s.get_mut().context = 0; Poll::Ready(Ok(StartedHandshake::Done(SslStream(s)))) } Err(ssl::HandshakeError::WouldBlock(mut s)) => { s.get_mut().context = 0; Poll::Ready(Ok(StartedHandshake::Mid(s))) } Err(e) => Poll::Ready(Err(HandshakeError(e))), } } } struct HandshakeFuture(Option>>); impl Future for HandshakeFuture where S: AsyncRead + AsyncWrite + Unpin, { type Output = Result, HandshakeError>; fn poll( mut self: Pin<&mut Self>, ctx: &mut Context<'_>, ) -> Poll, HandshakeError>> { let mut s = self.0.take().expect("future polled after completion"); s.get_mut().context = ctx as *mut _ as usize; match s.handshake() { Ok(mut s) => { s.get_mut().context = 0; Poll::Ready(Ok(SslStream(s))) } Err(ssl::HandshakeError::WouldBlock(mut s)) => { s.get_mut().context = 0; self.0 = Some(s); Poll::Pending } Err(e) => Poll::Ready(Err(HandshakeError(e))), } } }