diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 9e98812f..0ac912c4 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -1591,6 +1591,26 @@ pub struct CipherBits { #[repr(transparent)] pub struct ClientHello(ffi::SSL_CLIENT_HELLO); +impl ClientHello { + /// Returns the data of a given extension, if present. + /// + /// This corresponds to [`SSL_early_callback_ctx_extension_get`]. + /// + /// [`SSL_early_callback_ctx_extension_get`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_early_callback_ctx_extension_get + pub fn get_extension(&self, ext_type: u16) -> Option<&[u8]> { + unsafe { + let mut ptr = ptr::null(); + let mut len = 0; + let result = + ffi::SSL_early_callback_ctx_extension_get(&self.0, ext_type, &mut ptr, &mut len); + if result == 0 { + return None; + } + Some(slice::from_raw_parts(ptr, len)) + } + } +} + /// Information about a cipher. pub struct SslCipher(*mut ffi::SSL_CIPHER); diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index 0dc663f9..4f242a32 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -505,6 +505,53 @@ fn test_select_cert_error() { client.connect_err(); } +#[test] +fn test_select_cert_unknown_extension() { + let mut server = Server::builder(); + let unknown_extension = std::sync::Arc::new(std::sync::Mutex::new(None)); + + server.ctx().set_select_certificate_callback({ + let unknown = unknown_extension.clone(); + move |client_hello| { + *unknown.lock().unwrap() = client_hello.get_extension(1337).map(ToOwned::to_owned); + Ok(()) + } + }); + + let server = server.build(); + let client = server.client(); + + client.connect(); + assert_eq!(unknown_extension.lock().unwrap().as_deref(), None); +} + +#[test] +fn test_select_cert_alpn_extension() { + let mut server = Server::builder(); + let alpn_extension = std::sync::Arc::new(std::sync::Mutex::new(None)); + server.ctx().set_select_certificate_callback({ + let alpn = alpn_extension.clone(); + move |client_hello| { + *alpn.lock().unwrap() = Some( + client_hello + .get_extension(ffi::TLSEXT_TYPE_application_layer_protocol_negotiation as u16) + .unwrap() + .to_owned(), + ); + Ok(()) + } + }); + let server = server.build(); + + let mut client = server.client(); + client.ctx().set_alpn_protos(b"\x06http/2").unwrap(); + client.connect(); + assert_eq!( + alpn_extension.lock().unwrap().as_deref(), + Some(&b"\x00\x07\x06http/2"[..]), + ); +} + #[test] #[should_panic(expected = "blammo")] fn write_panic() {