From 83e049d8d983e84dc4ef62575b80b6f3889c206e Mon Sep 17 00:00:00 2001 From: 0x676e67 Date: Mon, 7 Jul 2025 21:10:50 +0800 Subject: [PATCH] feat(boring): adapt `boring2` for compio async runtime (#85) close: https://github.com/0x676e67/boring2/issues/78 --- .github/workflows/ci.yml | 8 +- Cargo.toml | 4 + README.md | 1 + compio-boring/Cargo.toml | 46 ++++ compio-boring/src/async_callbacks.rs | 77 +++++++ compio-boring/src/lib.rs | 331 +++++++++++++++++++++++++++ compio-boring/tests/client_server.rs | 32 +++ 7 files changed, 496 insertions(+), 3 deletions(-) create mode 100644 compio-boring/Cargo.toml create mode 100644 compio-boring/src/async_callbacks.rs create mode 100644 compio-boring/src/lib.rs create mode 100644 compio-boring/tests/client_server.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 414e4b81..82dd72d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -141,6 +141,7 @@ jobs: rust: stable os: ubuntu-latest apt_packages: gcc-multilib g++-multilib + extra_test_args: --workspace --exclude compio-boring2 - thing: arm-linux target: arm-unknown-linux-gnueabi rust: stable @@ -151,6 +152,7 @@ jobs: CC: arm-linux-gnueabi-gcc CXX: arm-linux-gnueabi-g++ CARGO_TARGET_ARM_UNKNOWN_LINUX_GNUEABI_LINKER: arm-linux-gnueabi-g++ + extra_test_args: --workspace --exclude compio-boring2 - thing: aarch64-linux target: aarch64-unknown-linux-gnu rust: stable @@ -182,19 +184,19 @@ jobs: CPLUS_INCLUDE_PATH: "C:\\msys64\\usr\\include" LIBRARY_PATH: "C:\\msys64\\usr\\lib" # CI's Windows doesn't have required root certs - extra_test_args: --workspace --exclude tokio-boring2 --exclude hyper-boring2 + extra_test_args: --workspace --exclude tokio-boring2 --exclude compio-boring2 - thing: i686-msvc target: i686-pc-windows-msvc rust: stable-x86_64-msvc os: windows-latest # CI's Windows doesn't have required root certs - extra_test_args: --workspace --exclude tokio-boring2 --exclude hyper-boring2 + extra_test_args: --workspace --exclude tokio-boring2 --exclude compio-boring2 - thing: x86_64-msvc target: x86_64-pc-windows-msvc rust: stable-x86_64-msvc os: windows-latest # CI's Windows doesn't have required root certs - extra_test_args: --workspace --exclude tokio-boring2 --exclude hyper-boring2 + extra_test_args: --workspace --exclude tokio-boring2 --exclude compio-boring2 steps: - uses: actions/checkout@v4 diff --git a/Cargo.toml b/Cargo.toml index de24da74..160c39cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "boring", "boring-sys", "tokio-boring", + "compio-boring", ] resolver = "2" @@ -21,6 +22,7 @@ publish = false boring-sys = { package = "boring-sys2", version = "5.0.0-alpha.3", path = "./boring-sys" } boring = { package = "boring2", version = "5.0.0-alpha.3", path = "./boring" } tokio-boring = { package = "tokio-boring2", version = "5.0.0-alpha.3", path = "./tokio-boring" } +compio-boring = { package = "compio-boring2", version = "0.1.0-alpha.1", path = "./compio-boring" } bindgen = { version = "0.72.0", default-features = false, features = ["runtime"] } bytes = "1" @@ -40,3 +42,5 @@ linked_hash_set = "0.1" openssl-macros = "0.1.1" autocfg = "1.3.0" brotli = "8" +compio = { version = "0.15.0" } +compio-io = { version = "0.7.0" } \ No newline at end of file diff --git a/README.md b/README.md index 9e5911a7..087887b8 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ This package implements only the TLS extensions specification and supports the o ## Documentation - Boring API: - tokio TLS adapters: + - compio TLS adapters: - FFI bindings: ## Contribution diff --git a/compio-boring/Cargo.toml b/compio-boring/Cargo.toml new file mode 100644 index 00000000..36c7bd43 --- /dev/null +++ b/compio-boring/Cargo.toml @@ -0,0 +1,46 @@ +[package] +name = "compio-boring2" +version = { workspace = true } +authors = ["0x676e67 "] +license = "MIT OR Apache-2.0" +edition = { workspace = true } +repository = { workspace = true } +documentation = "https://docs.rs/compio-boring2" +description = """ +An implementation of SSL streams for Compio backed by BoringSSL +""" + +[package.metadata.docs.rs] +features = ["pq-experimental"] +rustdoc-args = ["--cfg", "docsrs"] + +[features] +# Use a FIPS-validated version of boringssl. +fips = ["boring/fips", "boring-sys/fips"] + +# Use a FIPS build of BoringSSL, but don't set "fips-compat". +# +# As of boringSSL commit a430310d6563c0734ddafca7731570dfb683dc19, we no longer +# need to make exceptions for the types of BufLen, ProtosLen, and ValueLen, +# which means the "fips-compat" feature is no longer needed. +# +# TODO(cjpatton) Delete this feature and modify "fips" so that it doesn't imply +# "fips-compat". +fips-precompiled = ["boring/fips-precompiled"] + +# Link with precompiled FIPS-validated `bcm.o` module. +fips-link-precompiled = ["boring/fips-link-precompiled", "boring-sys/fips-link-precompiled"] + +# Enables experimental post-quantum crypto (https://blog.cloudflare.com/post-quantum-for-all/) +pq-experimental = ["boring/pq-experimental"] + +[dependencies] +boring = { workspace = true } +boring-sys = { workspace = true } +compio = { workspace = true } +compio-io = { workspace = true, features = ["compat"]} + +[dev-dependencies] +futures = { workspace = true } +compio = { workspace = true, features = [ "macros"] } +anyhow = { workspace = true } diff --git a/compio-boring/src/async_callbacks.rs b/compio-boring/src/async_callbacks.rs new file mode 100644 index 00000000..7d3888be --- /dev/null +++ b/compio-boring/src/async_callbacks.rs @@ -0,0 +1,77 @@ +use boring::ssl::{ + AsyncPrivateKeyMethod, AsyncSelectCertError, BoxGetSessionFuture, BoxSelectCertFuture, + ClientHello, SslContextBuilder, SslRef, +}; + +/// Extensions to [`SslContextBuilder`]. +/// +/// This trait provides additional methods to use async callbacks with boring. +pub trait SslContextBuilderExt: private::Sealed { + /// Sets a callback that is called before most [`ClientHello`] processing + /// and before the decision whether to resume a session is made. The + /// callback may inspect the [`ClientHello`] and configure the connection. + /// + /// This method uses a function that returns a future whose output is + /// itself a closure that will be passed [`ClientHello`] to configure + /// the connection based on the computations done in the future. + /// + /// See [`SslContextBuilder::set_select_certificate_callback`] for the sync + /// setter of this callback. + fn set_async_select_certificate_callback(&mut self, callback: F) + where + F: Fn(&mut ClientHello<'_>) -> Result + + Send + + Sync + + 'static; + + /// Configures a custom private key method on the context. + /// + /// See [`AsyncPrivateKeyMethod`] for more details. + fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod); + + /// Sets a callback that is called when a client proposed to resume a session + /// but it was not found in the internal cache. + /// + /// The callback is passed a reference to the session ID provided by the client. + /// It should return the session corresponding to that ID if available. This is + /// only used for servers, not clients. + /// + /// See [`SslContextBuilder::set_get_session_callback`] for the sync setter + /// of this callback. + /// + /// # Safety + /// + /// The returned [`SslSession`] must not be associated with a different [`SslContext`]. + unsafe fn set_async_get_session_callback(&mut self, callback: F) + where + F: Fn(&mut SslRef, &[u8]) -> Option + Send + Sync + 'static; +} + +impl SslContextBuilderExt for SslContextBuilder { + fn set_async_select_certificate_callback(&mut self, callback: F) + where + F: Fn(&mut ClientHello<'_>) -> Result + + Send + + Sync + + 'static, + { + self.set_async_select_certificate_callback(callback); + } + + fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) { + self.set_async_private_key_method(method); + } + + unsafe fn set_async_get_session_callback(&mut self, callback: F) + where + F: Fn(&mut SslRef, &[u8]) -> Option + Send + Sync + 'static, + { + self.set_async_get_session_callback(callback); + } +} + +mod private { + pub trait Sealed {} +} + +impl private::Sealed for SslContextBuilder {} diff --git a/compio-boring/src/lib.rs b/compio-boring/src/lib.rs new file mode 100644 index 00000000..36db63f8 --- /dev/null +++ b/compio-boring/src/lib.rs @@ -0,0 +1,331 @@ +//! Async TLS streams backed by BoringSSL +//! +//! This library is an implementation of TLS streams using BoringSSL for +//! negotiating the connection. Each TLS stream implements the `Read` and +//! `Write` traits to interact and interoperate with the rest of the futures I/O +//! ecosystem. Client connections initiated from this crate verify hostnames +//! automatically and by default. +//! +//! `tokio-boring` exports this ability through [`accept`] and [`connect`]. `accept` should +//! be used by servers, and `connect` by clients. These augment the functionality provided by the +//! [`boring`] crate, on which this crate is built. Configuration of TLS parameters is still +//! primarily done through the [`boring`] crate. +#![warn(missing_docs)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] + +mod async_callbacks; + +use boring::ssl::{self, ConnectConfiguration, ErrorCode, SslAcceptor, SslRef}; +use boring_sys as ffi; +use compio::buf::{IoBuf, IoBufMut}; +use compio::io::{AsyncRead, AsyncWrite}; +use compio::BufResult; +use compio_io::compat::SyncStream; +use std::error::Error; +use std::fmt; +use std::io; +use std::mem::MaybeUninit; + +pub use crate::async_callbacks::SslContextBuilderExt; +pub use boring::ssl::{ + AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish, + BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish, + BoxSelectCertFuture, ExDataFuture, +}; + +/// Asynchronously performs a client-side TLS handshake over the provided stream. +/// +/// This function automatically sets the task waker on the `Ssl` from `config` to +/// allow to make use of async callbacks provided by the boring crate. +pub async fn connect( + config: ConnectConfiguration, + domain: &str, + stream: S, +) -> Result, HandshakeError> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + let res = config.connect(domain, SyncStream::new(stream)); + perform_tls_handshake(res).await +} + +/// Asynchronously performs a server-side TLS handshake over the provided stream. +/// +/// This function automatically sets the task waker on the `Ssl` from `config` to +/// allow to make use of async callbacks provided by the boring crate. +pub async fn accept(acceptor: &SslAcceptor, stream: S) -> Result, HandshakeError> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + let res = acceptor.accept(SyncStream::new(stream)); + perform_tls_handshake(res).await +} + +/// A partially constructed `SslStream`, useful for unusual handshakes. +pub struct SslStreamBuilder { + inner: ssl::SslStreamBuilder>, +} + +impl SslStreamBuilder +where + S: AsyncRead + AsyncWrite + Unpin, +{ + /// Begins creating an `SslStream` atop `stream`. + pub fn new(ssl: ssl::Ssl, stream: S) -> Self { + Self { + inner: ssl::SslStreamBuilder::new(ssl, SyncStream::new(stream)), + } + } + + /// Initiates a client-side TLS handshake. + pub async fn accept(self) -> Result, HandshakeError> { + let res = self.inner.connect(); + perform_tls_handshake(res).await + } + + /// Initiates a server-side TLS handshake. + pub async fn connect(self) -> Result, HandshakeError> { + let res = self.inner.connect(); + perform_tls_handshake(res).await + } +} + +impl SslStreamBuilder { + /// Returns a shared reference to the `Ssl` object associated with this builder. + #[must_use] + pub fn ssl(&self) -> &SslRef { + self.inner.ssl() + } + + /// Returns a mutable reference to the `Ssl` object associated with this builder. + pub fn ssl_mut(&mut self) -> &mut SslRef { + self.inner.ssl_mut() + } +} + +/// A wrapper around an underlying raw stream which implements the SSL +/// protocol. +/// +/// A `SslStream` represents a handshake that has been completed successfully +/// and both the server and the client are ready for receiving and sending +/// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written +/// to a `SslStream` are encrypted when passing through to `S`. +#[derive(Debug)] +pub struct SslStream(ssl::SslStream>); + +impl SslStream { + /// Returns a shared reference to the `Ssl` object associated with this stream. + #[must_use] + pub fn ssl(&self) -> &SslRef { + self.0.ssl() + } + + /// Returns a mutable reference to the `Ssl` object associated with this stream. + pub fn ssl_mut(&mut self) -> &mut SslRef { + self.0.ssl_mut() + } + + /// Returns a shared reference to the underlying stream. + #[must_use] + pub fn get_ref(&self) -> &S { + self.0.get_ref().get_ref() + } + + /// Returns a mutable reference to the underlying stream. + pub fn get_mut(&mut self) -> &mut S { + self.0.get_mut().get_mut() + } +} + +impl SslStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + /// Constructs an `SslStream` from a pointer to the underlying OpenSSL `SSL` struct. + /// + /// This is useful if the handshake has already been completed elsewhere. + /// + /// # Safety + /// + /// The caller must ensure the pointer is valid. + pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self { + Self(ssl::SslStream::from_raw_parts(ssl, SyncStream::new(stream))) + } +} + +impl AsyncRead for SslStream { + async fn read(&mut self, mut buf: B) -> BufResult { + let slice = buf.as_mut_slice(); + + let mut f = { + slice.fill(MaybeUninit::new(0)); + // SAFETY: The memory has been initialized + let slice = + unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) }; + |s: &mut _| std::io::Read::read(s, slice) + }; + + loop { + match f(&mut self.0) { + Ok(res) => { + unsafe { buf.set_buf_init(res) }; + return BufResult(Ok(res), buf); + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + match self.0.get_mut().fill_read_buf().await { + Ok(_) => continue, + Err(e) => return BufResult(Err(e), buf), + } + } + res => return BufResult(res, buf), + } + } + } + + // OpenSSL does not support vectored reads +} + +/// `AsyncRead` is needed for shutting down stream. +impl AsyncWrite for SslStream { + async fn write(&mut self, buf: T) -> BufResult { + let slice = buf.as_slice(); + loop { + let res = io::Write::write(&mut self.0, slice); + match res { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.flush().await { + Ok(_) => continue, + Err(e) => return BufResult(Err(e), buf), + }, + _ => return BufResult(res, buf), + } + } + } + + async fn flush(&mut self) -> io::Result<()> { + loop { + match io::Write::flush(&mut self.0) { + Ok(()) => break, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.0.get_mut().flush_write_buf().await?; + } + Err(e) => return Err(e), + } + } + self.0.get_mut().flush_write_buf().await?; + Ok(()) + } + + async fn shutdown(&mut self) -> io::Result<()> { + self.flush().await?; + self.0.get_mut().get_mut().shutdown().await + } +} + +/// The error type returned after a failed handshake. +pub enum HandshakeError { + /// An error that occurred during the handshake. + Inner(ssl::HandshakeError>), + /// An I/O error that occurred during the handshake. + Io(io::Error), +} + +impl HandshakeError { + /// Returns a shared reference to the `Ssl` object associated with this error. + #[must_use] + pub fn ssl(&self) -> Option<&SslRef> { + match self { + HandshakeError::Inner(ssl::HandshakeError::Failure(s)) => Some(s.ssl()), + _ => None, + } + } + + /// Returns the error code, if any. + #[must_use] + pub fn code(&self) -> Option { + match self { + HandshakeError::Inner(ssl::HandshakeError::Failure(s)) => Some(s.error().code()), + _ => None, + } + } + + /// Returns a reference to the inner I/O error, if any. + #[must_use] + pub fn as_io_error(&self) -> Option<&io::Error> { + match self { + HandshakeError::Inner(ssl::HandshakeError::Failure(s)) => s.error().io_error(), + HandshakeError::Io(e) => Some(e), + _ => None, + } + } +} + +impl fmt::Debug for HandshakeError +where + S: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + HandshakeError::Inner(e) => fmt::Debug::fmt(e, fmt), + HandshakeError::Io(e) => fmt::Debug::fmt(e, fmt), + } + } +} + +impl fmt::Display for HandshakeError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + HandshakeError::Inner(e) => fmt::Display::fmt(e, fmt), + HandshakeError::Io(e) => fmt::Display::fmt(e, fmt), + } + } +} + +impl Error for HandshakeError +where + S: fmt::Debug, +{ + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + HandshakeError::Inner(e) => e.source(), + HandshakeError::Io(e) => Some(e), + } + } +} + +async fn perform_tls_handshake( + mut res: Result>, ssl::HandshakeError>>, +) -> Result, HandshakeError> { + loop { + match res { + Ok(mut s) => { + s.get_mut() + .flush_write_buf() + .await + .map_err(HandshakeError::Io)?; + return Ok(SslStream(s)); + } + Err(e) => match e { + ssl::HandshakeError::Failure(_) => return Err(HandshakeError::Inner(e)), + ssl::HandshakeError::SetupFailure(_) => { + return Err(HandshakeError::Inner(e)); + } + ssl::HandshakeError::WouldBlock(mut mid_stream) => { + if mid_stream + .get_mut() + .flush_write_buf() + .await + .map_err(HandshakeError::Io)? + == 0 + { + mid_stream + .get_mut() + .fill_read_buf() + .await + .map_err(HandshakeError::Io)?; + } + res = mid_stream.handshake(); + } + }, + } + } +} diff --git a/compio-boring/tests/client_server.rs b/compio-boring/tests/client_server.rs new file mode 100644 index 00000000..333fed0c --- /dev/null +++ b/compio-boring/tests/client_server.rs @@ -0,0 +1,32 @@ +use boring::ssl::{SslConnector, SslMethod}; +use compio::io::AsyncReadExt; +use compio::net::TcpStream; +use compio_io::AsyncWrite; +use std::net::ToSocketAddrs; + +#[compio::test] +async fn google() { + let addr = "google.com:443".to_socket_addrs().unwrap().next().unwrap(); + let stream = TcpStream::connect(&addr).await.unwrap(); + + let config = SslConnector::builder(SslMethod::tls()) + .unwrap() + .build() + .configure() + .unwrap(); + + let mut stream = compio_boring2::connect(config, "google.com", stream) + .await + .unwrap(); + + stream.write(b"GET / HTTP/1.0\r\n\r\n").await.unwrap(); + stream.flush().await.unwrap(); + let (_, buf) = stream.read_to_end(vec![]).await.unwrap(); + stream.shutdown().await.unwrap(); + let response = String::from_utf8_lossy(&buf); + let response = response.trim_end(); + + // any response code is fine + assert!(response.starts_with("HTTP/1.0 ")); + assert!(response.ends_with("") || response.ends_with("")); +}