From 0ffbdb030f9faec23f6576fea8f11797951344ae Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Wed, 2 Aug 2023 10:36:26 +0200 Subject: [PATCH] Implement SslContextBuilder::set_private_key_method --- boring/src/ssl/callbacks.rs | 111 ++++++++- boring/src/ssl/mod.rs | 96 ++++++++ boring/src/ssl/test/mod.rs | 32 ++- boring/src/ssl/test/private_key_method.rs | 282 ++++++++++++++++++++++ boring/src/ssl/test/server.rs | 33 ++- 5 files changed, 518 insertions(+), 36 deletions(-) create mode 100644 boring/src/ssl/test/private_key_method.rs diff --git a/boring/src/ssl/callbacks.rs b/boring/src/ssl/callbacks.rs index 1e7a64da..dc9f2d53 100644 --- a/boring/src/ssl/callbacks.rs +++ b/boring/src/ssl/callbacks.rs @@ -1,6 +1,13 @@ #![forbid(unsafe_op_in_unsafe_fn)] +use super::{ + AlpnError, ClientHello, PrivateKeyMethod, PrivateKeyMethodError, SelectCertError, SniError, + Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef, + SslSignatureAlgorithm, SESSION_CTX_INDEX, +}; +use crate::error::ErrorStack; use crate::ffi; +use crate::x509::{X509StoreContext, X509StoreContextRef}; use foreign_types::ForeignType; use foreign_types::ForeignTypeRef; use libc::c_char; @@ -12,19 +19,7 @@ use std::slice; use std::str; use std::sync::Arc; -use crate::error::ErrorStack; -use crate::ssl::AlpnError; -use crate::ssl::{ClientHello, SelectCertError}; -use crate::ssl::{ - SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef, - SESSION_CTX_INDEX, -}; -use crate::x509::{X509StoreContext, X509StoreContextRef}; - -pub(super) unsafe extern "C" fn raw_verify( - preverify_ok: c_int, - x509_ctx: *mut ffi::X509_STORE_CTX, -) -> c_int +pub extern "C" fn raw_verify(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_CTX) -> c_int where F: Fn(bool, &mut X509StoreContextRef) -> bool + 'static + Sync + Send, { @@ -372,3 +367,93 @@ where callback(ssl, line); } + +pub(super) unsafe extern "C" fn raw_sign( + ssl: *mut ffi::SSL, + out: *mut u8, + out_len: *mut usize, + max_out: usize, + signature_algorithm: u16, + in_: *const u8, + in_len: usize, +) -> ffi::ssl_private_key_result_t +where + M: PrivateKeyMethod, +{ + // SAFETY: boring provides valid inputs. + let input = unsafe { slice::from_raw_parts(in_, in_len) }; + + let signature_algorithm = SslSignatureAlgorithm(signature_algorithm); + + let callback = |method: &M, ssl: &mut _, output: &mut _| { + method.sign(ssl, input, signature_algorithm, output) + }; + + // SAFETY: boring provides valid inputs. + unsafe { raw_private_key_callback(ssl, out, out_len, max_out, callback) } +} + +pub(super) unsafe extern "C" fn raw_decrypt( + ssl: *mut ffi::SSL, + out: *mut u8, + out_len: *mut usize, + max_out: usize, + in_: *const u8, + in_len: usize, +) -> ffi::ssl_private_key_result_t +where + M: PrivateKeyMethod, +{ + // SAFETY: boring provides valid inputs. + let input = unsafe { slice::from_raw_parts(in_, in_len) }; + + let callback = |method: &M, ssl: &mut _, output: &mut _| method.decrypt(ssl, input, output); + + // SAFETY: boring provides valid inputs. + unsafe { raw_private_key_callback(ssl, out, out_len, max_out, callback) } +} + +pub(super) unsafe extern "C" fn raw_complete( + ssl: *mut ffi::SSL, + out: *mut u8, + out_len: *mut usize, + max_out: usize, +) -> ffi::ssl_private_key_result_t +where + M: PrivateKeyMethod, +{ + // SAFETY: boring provides valid inputs. + unsafe { raw_private_key_callback::(ssl, out, out_len, max_out, M::complete) } +} + +unsafe fn raw_private_key_callback( + ssl: *mut ffi::SSL, + out: *mut u8, + out_len: *mut usize, + max_out: usize, + callback: impl FnOnce(&M, &mut SslRef, &mut [u8]) -> Result, +) -> ffi::ssl_private_key_result_t +where + M: PrivateKeyMethod, +{ + // SAFETY: boring provides valid inputs. + let ssl = unsafe { SslRef::from_ptr_mut(ssl) }; + let output = unsafe { slice::from_raw_parts_mut(out, max_out) }; + let out_len = unsafe { &mut *out_len }; + + let ssl_context = ssl.ssl_context().to_owned(); + let method = ssl_context + .ex_data(SslContext::cached_ex_index::()) + .expect("BUG: private key method missing"); + + match callback(method, ssl, output) { + Ok(written) => { + assert!(written <= max_out); + + *out_len = written; + + ffi::ssl_private_key_result_t::ssl_private_key_success + } + Err(err) => err.0, + } +} diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 34925fb8..c0a6c341 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -1391,6 +1391,31 @@ impl SslContextBuilder { } } + /// Configures a custom private key method on the context. + /// + /// See [`PrivateKeyMethod`] for more details. + /// + /// This corresponds to [`SSL_CTX_set_private_key_method`] + /// + /// [`SSL_CTX_set_private_key_method`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_CTX_set_private_key_method + pub fn set_private_key_method(&mut self, method: M) + where + M: PrivateKeyMethod, + { + unsafe { + self.set_ex_data(SslContext::cached_ex_index::(), method); + + ffi::SSL_CTX_set_private_key_method( + self.as_ptr(), + &ffi::SSL_PRIVATE_KEY_METHOD { + sign: Some(callbacks::raw_sign::), + decrypt: Some(callbacks::raw_decrypt::), + complete: Some(callbacks::raw_complete::), + }, + ) + } + } + /// Checks for consistency between the private key and certificate. /// /// This corresponds to [`SSL_CTX_check_private_key`]. @@ -3790,6 +3815,77 @@ bitflags! { } } +/// Describes private key hooks. This is used to off-load signing operations to +/// a custom, potentially asynchronous, backend. Metadata about the key such as +/// the type and size are parsed out of the certificate. +/// +/// Corresponds to [`ssl_private_key_method_st`]. +/// +/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st +pub trait PrivateKeyMethod: Send + Sync + 'static { + /// Signs the message `input` using the specified signature algorithm. + /// + /// On success, it returns `Ok(written)` where `written` is the number of + /// bytes written into `output`. On failure, it returns + /// `Err(PrivateKeyMethodError::FAILURE)`. If the operation has not completed, + /// it returns `Err(PrivateKeyMethodError::RETRY)`. + /// + /// The caller should arrange for the high-level operation on `ssl` to be + /// retried when the operation is completed. This will result in a call to + /// [`Self::complete`]. + fn sign( + &self, + ssl: &mut SslRef, + input: &[u8], + signature_algorithm: SslSignatureAlgorithm, + output: &mut [u8], + ) -> Result; + + /// Decrypts `input`. + /// + /// On success, it returns `Ok(written)` where `written` is the number of + /// bytes written into `output`. On failure, it returns + /// `Err(PrivateKeyMethodError::FAILURE)`. If the operation has not completed, + /// it returns `Err(PrivateKeyMethodError::RETRY)`. + /// + /// The caller should arrange for the high-level operation on `ssl` to be + /// retried when the operation is completed. This will result in a call to + /// [`Self::complete`]. + /// + /// This method only works with RSA keys and should perform a raw RSA + /// decryption operation with no padding. + // NOTE(nox): What does it mean that it is an error? + fn decrypt( + &self, + ssl: &mut SslRef, + input: &[u8], + output: &mut [u8], + ) -> Result; + + /// Completes a pending operation. + /// + /// On success, it returns `Ok(written)` where `written` is the number of + /// bytes written into `output`. On failure, it returns + /// `Err(PrivateKeyMethodError::FAILURE)`. If the operation has not completed, + /// it returns `Err(PrivateKeyMethodError::RETRY)`. + /// + /// This method may be called arbitrarily many times before completion. + fn complete(&self, ssl: &mut SslRef, output: &mut [u8]) + -> Result; +} + +/// An error returned from a private key method. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct PrivateKeyMethodError(ffi::ssl_private_key_result_t); + +impl PrivateKeyMethodError { + /// A fatal error occured and the handshake should be terminated. + pub const FAILURE: Self = Self(ffi::ssl_private_key_result_t::ssl_private_key_failure); + + /// The operation could not be completed and should be retried later. + pub const RETRY: Self = Self(ffi::ssl_private_key_result_t::ssl_private_key_retry); +} + use crate::ffi::{SSL_CTX_up_ref, SSL_SESSION_get_master_key, SSL_SESSION_up_ref, SSL_is_server}; use crate::ffi::{DTLS_method, TLS_client_method, TLS_method, TLS_server_method}; diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index 5c986199..a68c3dc7 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -34,6 +34,7 @@ use crate::x509::store::X509StoreBuilder; use crate::x509::verify::X509CheckFlags; use crate::x509::{X509Name, X509StoreContext, X509VerifyResult, X509}; +mod private_key_method; mod server; static ROOT_CERT: &[u8] = include_bytes!("../../../test/root-ca.pem"); @@ -55,9 +56,7 @@ fn verify_untrusted() { #[test] fn verify_trusted() { let server = Server::builder().build(); - - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); + let client = server.client_with_root_ca(); client.connect(); } @@ -109,9 +108,8 @@ fn verify_untrusted_callback_override_bad() { #[test] fn verify_trusted_callback_override_ok() { let server = Server::builder().build(); + let mut client = server.client_with_root_ca(); - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); client .ctx() .set_verify_callback(SslVerifyMode::PEER, |_, x509| { @@ -125,11 +123,12 @@ fn verify_trusted_callback_override_ok() { #[test] fn verify_trusted_callback_override_bad() { let mut server = Server::builder(); - server.should_error(); - let server = server.build(); - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); + server.should_error(); + + let server = server.build(); + let mut client = server.client_with_root_ca(); + client .ctx() .set_verify_callback(SslVerifyMode::PEER, |_, _| false); @@ -155,9 +154,8 @@ fn verify_callback_load_certs() { #[test] fn verify_trusted_get_error_ok() { let server = Server::builder().build(); + let mut client = server.client_with_root_ca(); - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); client .ctx() .set_verify_callback(SslVerifyMode::PEER, |_, x509| { @@ -697,9 +695,8 @@ fn add_extra_chain_cert() { #[test] fn verify_valid_hostname() { let server = Server::builder().build(); + let mut client = server.client_with_root_ca(); - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); client.ctx().set_verify(SslVerifyMode::PEER); let mut client = client.build().builder(); @@ -714,11 +711,12 @@ fn verify_valid_hostname() { #[test] fn verify_invalid_hostname() { let mut server = Server::builder(); - server.should_error(); - let server = server.build(); - let mut client = server.client(); - client.ctx().set_ca_file("test/root-ca.pem").unwrap(); + server.should_error(); + + let server = server.build(); + let mut client = server.client_with_root_ca(); + client.ctx().set_verify(SslVerifyMode::PEER); let mut client = client.build().builder(); diff --git a/boring/src/ssl/test/private_key_method.rs b/boring/src/ssl/test/private_key_method.rs new file mode 100644 index 00000000..f711fccc --- /dev/null +++ b/boring/src/ssl/test/private_key_method.rs @@ -0,0 +1,282 @@ +use once_cell::sync::OnceCell; + +use super::server::{Builder, Server}; +use super::KEY; +use crate::hash::{Hasher, MessageDigest}; +use crate::pkey::PKey; +use crate::rsa::Padding; +use crate::sign::{RsaPssSaltlen, Signer}; +use crate::ssl::{ + ErrorCode, HandshakeError, PrivateKeyMethod, PrivateKeyMethodError, SslRef, + SslSignatureAlgorithm, +}; +use crate::x509::X509; +use std::cmp; +use std::io::{Read, Write}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +#[allow(clippy::type_complexity)] +pub(super) struct Method { + sign: Box< + dyn Fn( + &mut SslRef, + &[u8], + SslSignatureAlgorithm, + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + >, + decrypt: Box< + dyn Fn(&mut SslRef, &[u8], &mut [u8]) -> Result + + Send + + Sync + + 'static, + >, + complete: Box< + dyn Fn(&mut SslRef, &mut [u8]) -> Result + + Send + + Sync + + 'static, + >, +} + +impl Method { + pub(super) fn new() -> Self { + Self { + sign: Box::new(|_, _, _, _| unreachable!("called sign")), + decrypt: Box::new(|_, _, _| unreachable!("called decrypt")), + complete: Box::new(|_, _| unreachable!("called complete")), + } + } + + pub(super) fn sign( + mut self, + sign: impl Fn( + &mut SslRef, + &[u8], + SslSignatureAlgorithm, + &mut [u8], + ) -> Result + + Send + + Sync + + 'static, + ) -> Self { + self.sign = Box::new(sign); + + self + } + + #[allow(dead_code)] + pub(super) fn decrypt( + mut self, + decrypt: impl Fn(&mut SslRef, &[u8], &mut [u8]) -> Result + + Send + + Sync + + 'static, + ) -> Self { + self.decrypt = Box::new(decrypt); + + self + } + + pub(super) fn complete( + mut self, + complete: impl Fn(&mut SslRef, &mut [u8]) -> Result + + Send + + Sync + + 'static, + ) -> Self { + self.complete = Box::new(complete); + + self + } +} + +impl PrivateKeyMethod for Method { + fn sign( + &self, + ssl: &mut SslRef, + input: &[u8], + signature_algorithm: SslSignatureAlgorithm, + output: &mut [u8], + ) -> Result { + (self.sign)(ssl, input, signature_algorithm, output) + } + + fn decrypt( + &self, + ssl: &mut SslRef, + input: &[u8], + output: &mut [u8], + ) -> Result { + (self.decrypt)(ssl, input, output) + } + + fn complete( + &self, + ssl: &mut SslRef, + output: &mut [u8], + ) -> Result { + (self.complete)(ssl, output) + } +} + +fn builder_with_private_key_method(method: Method) -> Builder { + let mut builder = Server::builder(); + + builder.ctx().set_private_key_method(method); + + builder +} + +#[test] +fn test_sign_failure() { + let called_sign = Arc::new(AtomicBool::new(false)); + let called_sign_clone = called_sign.clone(); + + let mut builder = builder_with_private_key_method(Method::new().sign(move |_, _, _, _| { + called_sign_clone.store(true, Ordering::SeqCst); + + Err(PrivateKeyMethodError::FAILURE) + })); + + builder.err_cb(|error| { + let HandshakeError::Failure(mid_handshake) = error else { + panic!("should be Failure"); + }; + + assert_eq!(mid_handshake.error().code(), ErrorCode::SSL); + }); + + let server = builder.build(); + let client = server.client_with_root_ca(); + + client.connect_err(); + + assert!(called_sign.load(Ordering::SeqCst)); +} + +#[test] +fn test_sign_retry_complete_failure() { + let called_complete = Arc::new(AtomicUsize::new(0)); + let called_complete_clone = called_complete.clone(); + + let mut builder = builder_with_private_key_method( + Method::new() + .sign(|_, _, _, _| Err(PrivateKeyMethodError::RETRY)) + .complete(move |_, _| { + let old = called_complete_clone.fetch_add(1, Ordering::SeqCst); + + Err(if old == 0 { + PrivateKeyMethodError::RETRY + } else { + PrivateKeyMethodError::FAILURE + }) + }), + ); + + builder.err_cb(|error| { + let HandshakeError::WouldBlock(mid_handshake) = error else { + panic!("should be WouldBlock"); + }; + + assert!(mid_handshake.error().would_block()); + assert_eq!( + mid_handshake.error().code(), + ErrorCode::WANT_PRIVATE_KEY_OPERATION + ); + + let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err() else { + panic!("should be WouldBlock"); + }; + + assert_eq!( + mid_handshake.error().code(), + ErrorCode::WANT_PRIVATE_KEY_OPERATION + ); + + let HandshakeError::Failure(mid_handshake) = mid_handshake.handshake().unwrap_err() else { + panic!("should be Failure"); + }; + + assert_eq!(mid_handshake.error().code(), ErrorCode::SSL); + }); + + let server = builder.build(); + let client = server.client_with_root_ca(); + + client.connect_err(); + + assert_eq!(called_complete.load(Ordering::SeqCst), 2); +} + +#[test] +fn test_sign_ok() { + let server = builder_with_private_key_method(Method::new().sign( + |_, input, signature_algorithm, output| { + assert_eq!( + signature_algorithm, + SslSignatureAlgorithm::RSA_PSS_RSAE_SHA256, + ); + + Ok(sign_with_default_config(input, output)) + }, + )) + .build(); + + let client = server.client_with_root_ca(); + + client.connect(); +} + +#[test] +fn test_sign_retry_complete_ok() { + let input_cell = Arc::new(OnceCell::new()); + let input_cell_clone = input_cell.clone(); + + let mut builder = builder_with_private_key_method( + Method::new() + .sign(move |_, input, _, _| { + input_cell.set(input.to_owned()).unwrap(); + + Err(PrivateKeyMethodError::RETRY) + }) + .complete(move |_, output| { + let input = input_cell_clone.get().unwrap(); + + Ok(sign_with_default_config(input, output)) + }), + ); + + builder.err_cb(|error| { + let HandshakeError::WouldBlock(mid_handshake) = error else { + panic!("should be WouldBlock"); + }; + + let mut socket = mid_handshake.handshake().unwrap(); + + socket.write_all(&[0]).unwrap(); + }); + + let server = builder.build(); + let client = server.client_with_root_ca(); + + client.connect(); +} + +fn sign_with_default_config(input: &[u8], output: &mut [u8]) -> usize { + let pkey = PKey::private_key_from_pem(KEY).unwrap(); + let mut signer = Signer::new(MessageDigest::sha256(), &pkey).unwrap(); + + signer.set_rsa_padding(Padding::PKCS1_PSS).unwrap(); + signer + .set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH) + .unwrap(); + + signer.update(input).unwrap(); + + signer.sign(output).unwrap() +} diff --git a/boring/src/ssl/test/server.rs b/boring/src/ssl/test/server.rs index 41677e57..7d79cd75 100644 --- a/boring/src/ssl/test/server.rs +++ b/boring/src/ssl/test/server.rs @@ -2,7 +2,10 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::thread::{self, JoinHandle}; -use crate::ssl::{Ssl, SslContext, SslContextBuilder, SslFiletype, SslMethod, SslRef, SslStream}; +use crate::ssl::{ + HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype, + SslMethod, SslRef, SslStream, +}; pub struct Server { handle: Option>, @@ -28,6 +31,7 @@ impl Server { ctx, ssl_cb: Box::new(|_| {}), io_cb: Box::new(|_| {}), + err_cb: Box::new(|_| {}), should_error: false, } } @@ -39,6 +43,14 @@ impl Server { } } + pub fn client_with_root_ca(&self) -> ClientBuilder { + let mut client = self.client(); + + client.ctx().set_ca_file("test/root-ca.pem").unwrap(); + + client + } + pub fn connect_tcp(&self) -> TcpStream { TcpStream::connect(self.addr).unwrap() } @@ -48,6 +60,7 @@ pub struct Builder { ctx: SslContextBuilder, ssl_cb: Box, io_cb: Box) + Send>, + err_cb: Box) + Send>, should_error: bool, } @@ -70,6 +83,12 @@ impl Builder { self.io_cb = Box::new(cb); } + pub fn err_cb(&mut self, cb: impl FnMut(HandshakeError) + Send + 'static) { + self.should_error(); + + self.err_cb = Box::new(cb); + } + pub fn should_error(&mut self) { self.should_error = true; } @@ -80,6 +99,7 @@ impl Builder { let addr = socket.local_addr().unwrap(); let mut ssl_cb = self.ssl_cb; let mut io_cb = self.io_cb; + let mut err_cb = self.err_cb; let should_error = self.should_error; let handle = thread::spawn(move || { @@ -88,7 +108,7 @@ impl Builder { ssl_cb(&mut ssl); let r = ssl.accept(socket); if should_error { - r.unwrap_err(); + err_cb(r.unwrap_err()); } else { let mut socket = r.unwrap(); socket.write_all(&[0]).unwrap(); @@ -124,8 +144,8 @@ impl ClientBuilder { self.build().builder().connect() } - pub fn connect_err(self) { - self.build().builder().connect_err(); + pub fn connect_err(self) -> HandshakeError { + self.build().builder().connect_err() } } @@ -160,8 +180,9 @@ impl ClientSslBuilder { s } - pub fn connect_err(self) { + pub fn connect_err(self) -> HandshakeError { let socket = TcpStream::connect(self.addr).unwrap(); - self.ssl.connect(socket).unwrap_err(); + + self.ssl.setup_connect(socket).handshake().unwrap_err() } }