diff --git a/boring/src/ssl/callbacks.rs b/boring/src/ssl/callbacks.rs index eb595019..64e7a05e 100644 --- a/boring/src/ssl/callbacks.rs +++ b/boring/src/ssl/callbacks.rs @@ -12,6 +12,7 @@ use std::sync::Arc; use error::ErrorStack; use ssl::AlpnError; +use ssl::{ClientHello, SelectCertError}; use ssl::{ SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef, SESSION_CTX_INDEX, @@ -190,6 +191,25 @@ where } } +pub unsafe extern "C" fn raw_select_cert( + client_hello: *const ffi::SSL_CLIENT_HELLO, +) -> ffi::ssl_select_cert_result_t +where + F: Fn(&ClientHello) -> Result<(), SelectCertError> + Sync + Send + 'static, +{ + let ssl = SslRef::from_ptr_mut((*client_hello).ssl); + let client_hello = &*(client_hello as *const ClientHello); + let callback = ssl + .ssl_context() + .ex_data(SslContext::cached_ex_index::()) + .expect("BUG: select cert callback missing") as *const F; + + match (*callback)(client_hello) { + Ok(()) => ffi::ssl_select_cert_result_t::ssl_select_cert_success, + Err(e) => e.0, + } +} + pub unsafe extern "C" fn raw_tlsext_status(ssl: *mut ffi::SSL, _: *mut c_void) -> c_int where F: Fn(&mut SslRef) -> Result + 'static + Sync + Send, diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 25e3f7ca..a71eeedf 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -463,6 +463,62 @@ impl AlpnError { pub const NOACK: AlpnError = AlpnError(ffi::SSL_TLSEXT_ERR_NOACK); } +/// An error returned from a certificate selection callback. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct SelectCertError(ffi::ssl_select_cert_result_t); + +impl SelectCertError { + /// A fatal error occured and the handshake should be terminated. + pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error); +} + +/// Extension types, to be used with `ClientHello::get_extension`. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct ExtensionType(u16); + +impl ExtensionType { + pub const SERVER_NAME: Self = Self(ffi::TLSEXT_TYPE_server_name as u16); + pub const STATUS_REQUEST: Self = Self(ffi::TLSEXT_TYPE_status_request as u16); + pub const EC_POINT_FORMATS: Self = Self(ffi::TLSEXT_TYPE_ec_point_formats as u16); + pub const SIGNATURE_ALGORITHMS: Self = Self(ffi::TLSEXT_TYPE_signature_algorithms as u16); + pub const SRTP: Self = Self(ffi::TLSEXT_TYPE_srtp as u16); + pub const APPLICATION_LAYER_PROTOCOL_NEGOTIATION: Self = + Self(ffi::TLSEXT_TYPE_application_layer_protocol_negotiation as u16); + pub const PADDING: Self = Self(ffi::TLSEXT_TYPE_padding as u16); + pub const EXTENDED_MASTER_SECRET: Self = Self(ffi::TLSEXT_TYPE_extended_master_secret as u16); + pub const TOKEN_BINDING: Self = Self(ffi::TLSEXT_TYPE_token_binding as u16); + pub const QUIC_TRANSPORT_PARAMETERS_LEGACY: Self = + Self(ffi::TLSEXT_TYPE_quic_transport_parameters_legacy as u16); + pub const QUIC_TRANSPORT_PARAMETERS_STANDARD: Self = + Self(ffi::TLSEXT_TYPE_quic_transport_parameters_standard as u16); + pub const CERT_COMPRESSION: Self = Self(ffi::TLSEXT_TYPE_cert_compression as u16); + pub const SESSION_TICKET: Self = Self(ffi::TLSEXT_TYPE_session_ticket as u16); + pub const SUPPORTED_GROUPS: Self = Self(ffi::TLSEXT_TYPE_supported_groups as u16); + pub const PRE_SHARED_KEY: Self = Self(ffi::TLSEXT_TYPE_pre_shared_key as u16); + pub const EARLY_DATA: Self = Self(ffi::TLSEXT_TYPE_early_data as u16); + pub const SUPPORTED_VERSIONS: Self = Self(ffi::TLSEXT_TYPE_supported_versions as u16); + pub const COOKIE: Self = Self(ffi::TLSEXT_TYPE_cookie as u16); + pub const PSK_KEY_EXCHANGE_MODES: Self = Self(ffi::TLSEXT_TYPE_psk_key_exchange_modes as u16); + pub const CERTIFICATE_AUTHORITIES: Self = Self(ffi::TLSEXT_TYPE_certificate_authorities as u16); + pub const SIGNATURE_ALGORITHMS_CERT: Self = + Self(ffi::TLSEXT_TYPE_signature_algorithms_cert as u16); + pub const KEY_SHARE: Self = Self(ffi::TLSEXT_TYPE_key_share as u16); + pub const RENEGOTIATE: Self = Self(ffi::TLSEXT_TYPE_renegotiate as u16); + pub const DELEGATED_CREDENTIAL: Self = Self(ffi::TLSEXT_TYPE_delegated_credential as u16); + pub const APPLICATION_SETTINGS: Self = Self(ffi::TLSEXT_TYPE_application_settings as u16); + pub const ENCRYPTED_CLIENT_HELLO: Self = Self(ffi::TLSEXT_TYPE_encrypted_client_hello as u16); + pub const ECH_IS_INNER: Self = Self(ffi::TLSEXT_TYPE_ech_is_inner as u16); + pub const CERTIFICATE_TIMESTAMP: Self = Self(ffi::TLSEXT_TYPE_certificate_timestamp as u16); + pub const NEXT_PROTO_NEG: Self = Self(ffi::TLSEXT_TYPE_next_proto_neg as u16); + pub const CHANNEL_ID: Self = Self(ffi::TLSEXT_TYPE_channel_id as u16); +} + +impl From for ExtensionType { + fn from(value: u16) -> Self { + Self(value) + } +} + /// An SSL/TLS protocol version. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct SslVersion(u16); @@ -1084,6 +1140,25 @@ impl SslContextBuilder { ); } } + /// Sets a callback that is called before most ClientHello processing and before the decision whether + /// to resume a session is made. The callback may inspect the ClientHello and configure the + /// connection. + /// + /// This corresponds to [`SSL_CTX_set_select_certificate_cb`]. + /// + /// [`SSL_CTX_set_select_certificate_cb`]: https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_set_select_certificate_cb.html + pub fn set_select_certificate_callback(&mut self, callback: F) + where + F: Fn(&ClientHello) -> Result<(), SelectCertError> + Sync + Send + 'static, + { + unsafe { + self.set_ex_data(SslContext::cached_ex_index::(), callback); + ffi::SSL_CTX_set_select_certificate_cb( + self.as_ptr(), + Some(callbacks::raw_select_cert::), + ); + } + } /// Checks for consistency between the private key and certificate. /// @@ -1560,6 +1635,29 @@ pub struct CipherBits { pub algorithm: i32, } +#[repr(transparent)] +pub struct ClientHello(ffi::SSL_CLIENT_HELLO); + +impl ClientHello { + /// Returns the data of a given extension, if present. + /// + /// This corresponds to [`SSL_early_callback_ctx_extension_get`]. + /// + /// [`SSL_early_callback_ctx_extension_get`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_early_callback_ctx_extension_get + pub fn get_extension(&self, ext_type: ExtensionType) -> Option<&[u8]> { + unsafe { + let mut ptr = ptr::null(); + let mut len = 0; + let result = + ffi::SSL_early_callback_ctx_extension_get(&self.0, ext_type.0, &mut ptr, &mut len); + if result == 0 { + return None; + } + Some(slice::from_raw_parts(ptr, len)) + } + } +} + /// Information about a cipher. pub struct SslCipher(*mut ffi::SSL_CIPHER); diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index 35e43dfd..7fddcd0b 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -25,9 +25,10 @@ use ssl; use ssl::test::server::Server; use ssl::SslVersion; use ssl::{ - Error, HandshakeError, MidHandshakeSslStream, ShutdownResult, ShutdownState, Ssl, SslAcceptor, - SslAcceptorBuilder, SslConnector, SslContext, SslContextBuilder, SslFiletype, SslMethod, - SslOptions, SslSessionCacheMode, SslStream, SslStreamBuilder, SslVerifyMode, StatusType, + Error, ExtensionType, HandshakeError, MidHandshakeSslStream, ShutdownResult, ShutdownState, + Ssl, SslAcceptor, SslAcceptorBuilder, SslConnector, SslContext, SslContextBuilder, SslFiletype, + SslMethod, SslOptions, SslSessionCacheMode, SslStream, SslStreamBuilder, SslVerifyMode, + StatusType, }; use x509::store::X509StoreBuilder; use x509::verify::X509CheckFlags; @@ -480,6 +481,80 @@ fn test_alpn_server_unilateral() { assert_eq!(None, s.ssl().selected_alpn_protocol()); } +#[test] +fn test_select_cert_ok() { + let mut server = Server::builder(); + server + .ctx() + .set_select_certificate_callback(|_client_hello| Ok(())); + let server = server.build(); + + let client = server.client(); + client.connect(); +} + +#[test] +fn test_select_cert_error() { + let mut server = Server::builder(); + server.should_error(); + server + .ctx() + .set_select_certificate_callback(|_client_hello| Err(ssl::SelectCertError::ERROR)); + let server = server.build(); + + let client = server.client(); + client.connect_err(); +} + +#[test] +fn test_select_cert_unknown_extension() { + let mut server = Server::builder(); + let unknown_extension = std::sync::Arc::new(std::sync::Mutex::new(None)); + + server.ctx().set_select_certificate_callback({ + let unknown = unknown_extension.clone(); + move |client_hello| { + *unknown.lock().unwrap() = client_hello + .get_extension(ExtensionType::QUIC_TRANSPORT_PARAMETERS_LEGACY) + .map(ToOwned::to_owned); + Ok(()) + } + }); + + let server = server.build(); + let client = server.client(); + + client.connect(); + assert_eq!(unknown_extension.lock().unwrap().as_deref(), None); +} + +#[test] +fn test_select_cert_alpn_extension() { + let mut server = Server::builder(); + let alpn_extension = std::sync::Arc::new(std::sync::Mutex::new(None)); + server.ctx().set_select_certificate_callback({ + let alpn = alpn_extension.clone(); + move |client_hello| { + *alpn.lock().unwrap() = Some( + client_hello + .get_extension(ExtensionType::APPLICATION_LAYER_PROTOCOL_NEGOTIATION) + .unwrap() + .to_owned(), + ); + Ok(()) + } + }); + let server = server.build(); + + let mut client = server.client(); + client.ctx().set_alpn_protos(b"\x06http/2").unwrap(); + client.connect(); + assert_eq!( + alpn_extension.lock().unwrap().as_deref(), + Some(&b"\x00\x07\x06http/2"[..]), + ); +} + #[test] #[should_panic(expected = "blammo")] fn write_panic() {