Implement SslContextBuilder::set_private_key_method

This commit is contained in:
Anthony Ramine 2023-08-02 10:36:26 +02:00 committed by Alessandro Ghedini
parent 61bfbb5bd6
commit 0ffbdb030f
5 changed files with 518 additions and 36 deletions

View File

@ -1,6 +1,13 @@
#![forbid(unsafe_op_in_unsafe_fn)]
use super::{
AlpnError, ClientHello, PrivateKeyMethod, PrivateKeyMethodError, SelectCertError, SniError,
Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef,
SslSignatureAlgorithm, SESSION_CTX_INDEX,
};
use crate::error::ErrorStack;
use crate::ffi;
use crate::x509::{X509StoreContext, X509StoreContextRef};
use foreign_types::ForeignType;
use foreign_types::ForeignTypeRef;
use libc::c_char;
@ -12,19 +19,7 @@ use std::slice;
use std::str;
use std::sync::Arc;
use crate::error::ErrorStack;
use crate::ssl::AlpnError;
use crate::ssl::{ClientHello, SelectCertError};
use crate::ssl::{
SniError, Ssl, SslAlert, SslContext, SslContextRef, SslRef, SslSession, SslSessionRef,
SESSION_CTX_INDEX,
};
use crate::x509::{X509StoreContext, X509StoreContextRef};
pub(super) unsafe extern "C" fn raw_verify<F>(
preverify_ok: c_int,
x509_ctx: *mut ffi::X509_STORE_CTX,
) -> c_int
pub extern "C" fn raw_verify<F>(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_CTX) -> c_int
where
F: Fn(bool, &mut X509StoreContextRef) -> bool + 'static + Sync + Send,
{
@ -372,3 +367,93 @@ where
callback(ssl, line);
}
pub(super) unsafe extern "C" fn raw_sign<M>(
ssl: *mut ffi::SSL,
out: *mut u8,
out_len: *mut usize,
max_out: usize,
signature_algorithm: u16,
in_: *const u8,
in_len: usize,
) -> ffi::ssl_private_key_result_t
where
M: PrivateKeyMethod,
{
// SAFETY: boring provides valid inputs.
let input = unsafe { slice::from_raw_parts(in_, in_len) };
let signature_algorithm = SslSignatureAlgorithm(signature_algorithm);
let callback = |method: &M, ssl: &mut _, output: &mut _| {
method.sign(ssl, input, signature_algorithm, output)
};
// SAFETY: boring provides valid inputs.
unsafe { raw_private_key_callback(ssl, out, out_len, max_out, callback) }
}
pub(super) unsafe extern "C" fn raw_decrypt<M>(
ssl: *mut ffi::SSL,
out: *mut u8,
out_len: *mut usize,
max_out: usize,
in_: *const u8,
in_len: usize,
) -> ffi::ssl_private_key_result_t
where
M: PrivateKeyMethod,
{
// SAFETY: boring provides valid inputs.
let input = unsafe { slice::from_raw_parts(in_, in_len) };
let callback = |method: &M, ssl: &mut _, output: &mut _| method.decrypt(ssl, input, output);
// SAFETY: boring provides valid inputs.
unsafe { raw_private_key_callback(ssl, out, out_len, max_out, callback) }
}
pub(super) unsafe extern "C" fn raw_complete<M>(
ssl: *mut ffi::SSL,
out: *mut u8,
out_len: *mut usize,
max_out: usize,
) -> ffi::ssl_private_key_result_t
where
M: PrivateKeyMethod,
{
// SAFETY: boring provides valid inputs.
unsafe { raw_private_key_callback::<M>(ssl, out, out_len, max_out, M::complete) }
}
unsafe fn raw_private_key_callback<M>(
ssl: *mut ffi::SSL,
out: *mut u8,
out_len: *mut usize,
max_out: usize,
callback: impl FnOnce(&M, &mut SslRef, &mut [u8]) -> Result<usize, PrivateKeyMethodError>,
) -> ffi::ssl_private_key_result_t
where
M: PrivateKeyMethod,
{
// SAFETY: boring provides valid inputs.
let ssl = unsafe { SslRef::from_ptr_mut(ssl) };
let output = unsafe { slice::from_raw_parts_mut(out, max_out) };
let out_len = unsafe { &mut *out_len };
let ssl_context = ssl.ssl_context().to_owned();
let method = ssl_context
.ex_data(SslContext::cached_ex_index::<M>())
.expect("BUG: private key method missing");
match callback(method, ssl, output) {
Ok(written) => {
assert!(written <= max_out);
*out_len = written;
ffi::ssl_private_key_result_t::ssl_private_key_success
}
Err(err) => err.0,
}
}

View File

@ -1391,6 +1391,31 @@ impl SslContextBuilder {
}
}
/// Configures a custom private key method on the context.
///
/// See [`PrivateKeyMethod`] for more details.
///
/// This corresponds to [`SSL_CTX_set_private_key_method`]
///
/// [`SSL_CTX_set_private_key_method`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_CTX_set_private_key_method
pub fn set_private_key_method<M>(&mut self, method: M)
where
M: PrivateKeyMethod,
{
unsafe {
self.set_ex_data(SslContext::cached_ex_index::<M>(), method);
ffi::SSL_CTX_set_private_key_method(
self.as_ptr(),
&ffi::SSL_PRIVATE_KEY_METHOD {
sign: Some(callbacks::raw_sign::<M>),
decrypt: Some(callbacks::raw_decrypt::<M>),
complete: Some(callbacks::raw_complete::<M>),
},
)
}
}
/// Checks for consistency between the private key and certificate.
///
/// This corresponds to [`SSL_CTX_check_private_key`].
@ -3790,6 +3815,77 @@ bitflags! {
}
}
/// Describes private key hooks. This is used to off-load signing operations to
/// a custom, potentially asynchronous, backend. Metadata about the key such as
/// the type and size are parsed out of the certificate.
///
/// Corresponds to [`ssl_private_key_method_st`].
///
/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st
pub trait PrivateKeyMethod: Send + Sync + 'static {
/// Signs the message `input` using the specified signature algorithm.
///
/// On success, it returns `Ok(written)` where `written` is the number of
/// bytes written into `output`. On failure, it returns
/// `Err(PrivateKeyMethodError::FAILURE)`. If the operation has not completed,
/// it returns `Err(PrivateKeyMethodError::RETRY)`.
///
/// The caller should arrange for the high-level operation on `ssl` to be
/// retried when the operation is completed. This will result in a call to
/// [`Self::complete`].
fn sign(
&self,
ssl: &mut SslRef,
input: &[u8],
signature_algorithm: SslSignatureAlgorithm,
output: &mut [u8],
) -> Result<usize, PrivateKeyMethodError>;
/// Decrypts `input`.
///
/// On success, it returns `Ok(written)` where `written` is the number of
/// bytes written into `output`. On failure, it returns
/// `Err(PrivateKeyMethodError::FAILURE)`. If the operation has not completed,
/// it returns `Err(PrivateKeyMethodError::RETRY)`.
///
/// The caller should arrange for the high-level operation on `ssl` to be
/// retried when the operation is completed. This will result in a call to
/// [`Self::complete`].
///
/// This method only works with RSA keys and should perform a raw RSA
/// decryption operation with no padding.
// NOTE(nox): What does it mean that it is an error?
fn decrypt(
&self,
ssl: &mut SslRef,
input: &[u8],
output: &mut [u8],
) -> Result<usize, PrivateKeyMethodError>;
/// Completes a pending operation.
///
/// On success, it returns `Ok(written)` where `written` is the number of
/// bytes written into `output`. On failure, it returns
/// `Err(PrivateKeyMethodError::FAILURE)`. If the operation has not completed,
/// it returns `Err(PrivateKeyMethodError::RETRY)`.
///
/// This method may be called arbitrarily many times before completion.
fn complete(&self, ssl: &mut SslRef, output: &mut [u8])
-> Result<usize, PrivateKeyMethodError>;
}
/// An error returned from a private key method.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct PrivateKeyMethodError(ffi::ssl_private_key_result_t);
impl PrivateKeyMethodError {
/// A fatal error occured and the handshake should be terminated.
pub const FAILURE: Self = Self(ffi::ssl_private_key_result_t::ssl_private_key_failure);
/// The operation could not be completed and should be retried later.
pub const RETRY: Self = Self(ffi::ssl_private_key_result_t::ssl_private_key_retry);
}
use crate::ffi::{SSL_CTX_up_ref, SSL_SESSION_get_master_key, SSL_SESSION_up_ref, SSL_is_server};
use crate::ffi::{DTLS_method, TLS_client_method, TLS_method, TLS_server_method};

View File

@ -34,6 +34,7 @@ use crate::x509::store::X509StoreBuilder;
use crate::x509::verify::X509CheckFlags;
use crate::x509::{X509Name, X509StoreContext, X509VerifyResult, X509};
mod private_key_method;
mod server;
static ROOT_CERT: &[u8] = include_bytes!("../../../test/root-ca.pem");
@ -55,9 +56,7 @@ fn verify_untrusted() {
#[test]
fn verify_trusted() {
let server = Server::builder().build();
let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
let client = server.client_with_root_ca();
client.connect();
}
@ -109,9 +108,8 @@ fn verify_untrusted_callback_override_bad() {
#[test]
fn verify_trusted_callback_override_ok() {
let server = Server::builder().build();
let mut client = server.client_with_root_ca();
let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
client
.ctx()
.set_verify_callback(SslVerifyMode::PEER, |_, x509| {
@ -125,11 +123,12 @@ fn verify_trusted_callback_override_ok() {
#[test]
fn verify_trusted_callback_override_bad() {
let mut server = Server::builder();
server.should_error();
let server = server.build();
let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
server.should_error();
let server = server.build();
let mut client = server.client_with_root_ca();
client
.ctx()
.set_verify_callback(SslVerifyMode::PEER, |_, _| false);
@ -155,9 +154,8 @@ fn verify_callback_load_certs() {
#[test]
fn verify_trusted_get_error_ok() {
let server = Server::builder().build();
let mut client = server.client_with_root_ca();
let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
client
.ctx()
.set_verify_callback(SslVerifyMode::PEER, |_, x509| {
@ -697,9 +695,8 @@ fn add_extra_chain_cert() {
#[test]
fn verify_valid_hostname() {
let server = Server::builder().build();
let mut client = server.client_with_root_ca();
let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
client.ctx().set_verify(SslVerifyMode::PEER);
let mut client = client.build().builder();
@ -714,11 +711,12 @@ fn verify_valid_hostname() {
#[test]
fn verify_invalid_hostname() {
let mut server = Server::builder();
server.should_error();
let server = server.build();
let mut client = server.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
server.should_error();
let server = server.build();
let mut client = server.client_with_root_ca();
client.ctx().set_verify(SslVerifyMode::PEER);
let mut client = client.build().builder();

View File

@ -0,0 +1,282 @@
use once_cell::sync::OnceCell;
use super::server::{Builder, Server};
use super::KEY;
use crate::hash::{Hasher, MessageDigest};
use crate::pkey::PKey;
use crate::rsa::Padding;
use crate::sign::{RsaPssSaltlen, Signer};
use crate::ssl::{
ErrorCode, HandshakeError, PrivateKeyMethod, PrivateKeyMethodError, SslRef,
SslSignatureAlgorithm,
};
use crate::x509::X509;
use std::cmp;
use std::io::{Read, Write};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
#[allow(clippy::type_complexity)]
pub(super) struct Method {
sign: Box<
dyn Fn(
&mut SslRef,
&[u8],
SslSignatureAlgorithm,
&mut [u8],
) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
>,
decrypt: Box<
dyn Fn(&mut SslRef, &[u8], &mut [u8]) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
>,
complete: Box<
dyn Fn(&mut SslRef, &mut [u8]) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
>,
}
impl Method {
pub(super) fn new() -> Self {
Self {
sign: Box::new(|_, _, _, _| unreachable!("called sign")),
decrypt: Box::new(|_, _, _| unreachable!("called decrypt")),
complete: Box::new(|_, _| unreachable!("called complete")),
}
}
pub(super) fn sign(
mut self,
sign: impl Fn(
&mut SslRef,
&[u8],
SslSignatureAlgorithm,
&mut [u8],
) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
) -> Self {
self.sign = Box::new(sign);
self
}
#[allow(dead_code)]
pub(super) fn decrypt(
mut self,
decrypt: impl Fn(&mut SslRef, &[u8], &mut [u8]) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
) -> Self {
self.decrypt = Box::new(decrypt);
self
}
pub(super) fn complete(
mut self,
complete: impl Fn(&mut SslRef, &mut [u8]) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
) -> Self {
self.complete = Box::new(complete);
self
}
}
impl PrivateKeyMethod for Method {
fn sign(
&self,
ssl: &mut SslRef,
input: &[u8],
signature_algorithm: SslSignatureAlgorithm,
output: &mut [u8],
) -> Result<usize, PrivateKeyMethodError> {
(self.sign)(ssl, input, signature_algorithm, output)
}
fn decrypt(
&self,
ssl: &mut SslRef,
input: &[u8],
output: &mut [u8],
) -> Result<usize, PrivateKeyMethodError> {
(self.decrypt)(ssl, input, output)
}
fn complete(
&self,
ssl: &mut SslRef,
output: &mut [u8],
) -> Result<usize, PrivateKeyMethodError> {
(self.complete)(ssl, output)
}
}
fn builder_with_private_key_method(method: Method) -> Builder {
let mut builder = Server::builder();
builder.ctx().set_private_key_method(method);
builder
}
#[test]
fn test_sign_failure() {
let called_sign = Arc::new(AtomicBool::new(false));
let called_sign_clone = called_sign.clone();
let mut builder = builder_with_private_key_method(Method::new().sign(move |_, _, _, _| {
called_sign_clone.store(true, Ordering::SeqCst);
Err(PrivateKeyMethodError::FAILURE)
}));
builder.err_cb(|error| {
let HandshakeError::Failure(mid_handshake) = error else {
panic!("should be Failure");
};
assert_eq!(mid_handshake.error().code(), ErrorCode::SSL);
});
let server = builder.build();
let client = server.client_with_root_ca();
client.connect_err();
assert!(called_sign.load(Ordering::SeqCst));
}
#[test]
fn test_sign_retry_complete_failure() {
let called_complete = Arc::new(AtomicUsize::new(0));
let called_complete_clone = called_complete.clone();
let mut builder = builder_with_private_key_method(
Method::new()
.sign(|_, _, _, _| Err(PrivateKeyMethodError::RETRY))
.complete(move |_, _| {
let old = called_complete_clone.fetch_add(1, Ordering::SeqCst);
Err(if old == 0 {
PrivateKeyMethodError::RETRY
} else {
PrivateKeyMethodError::FAILURE
})
}),
);
builder.err_cb(|error| {
let HandshakeError::WouldBlock(mid_handshake) = error else {
panic!("should be WouldBlock");
};
assert!(mid_handshake.error().would_block());
assert_eq!(
mid_handshake.error().code(),
ErrorCode::WANT_PRIVATE_KEY_OPERATION
);
let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err() else {
panic!("should be WouldBlock");
};
assert_eq!(
mid_handshake.error().code(),
ErrorCode::WANT_PRIVATE_KEY_OPERATION
);
let HandshakeError::Failure(mid_handshake) = mid_handshake.handshake().unwrap_err() else {
panic!("should be Failure");
};
assert_eq!(mid_handshake.error().code(), ErrorCode::SSL);
});
let server = builder.build();
let client = server.client_with_root_ca();
client.connect_err();
assert_eq!(called_complete.load(Ordering::SeqCst), 2);
}
#[test]
fn test_sign_ok() {
let server = builder_with_private_key_method(Method::new().sign(
|_, input, signature_algorithm, output| {
assert_eq!(
signature_algorithm,
SslSignatureAlgorithm::RSA_PSS_RSAE_SHA256,
);
Ok(sign_with_default_config(input, output))
},
))
.build();
let client = server.client_with_root_ca();
client.connect();
}
#[test]
fn test_sign_retry_complete_ok() {
let input_cell = Arc::new(OnceCell::new());
let input_cell_clone = input_cell.clone();
let mut builder = builder_with_private_key_method(
Method::new()
.sign(move |_, input, _, _| {
input_cell.set(input.to_owned()).unwrap();
Err(PrivateKeyMethodError::RETRY)
})
.complete(move |_, output| {
let input = input_cell_clone.get().unwrap();
Ok(sign_with_default_config(input, output))
}),
);
builder.err_cb(|error| {
let HandshakeError::WouldBlock(mid_handshake) = error else {
panic!("should be WouldBlock");
};
let mut socket = mid_handshake.handshake().unwrap();
socket.write_all(&[0]).unwrap();
});
let server = builder.build();
let client = server.client_with_root_ca();
client.connect();
}
fn sign_with_default_config(input: &[u8], output: &mut [u8]) -> usize {
let pkey = PKey::private_key_from_pem(KEY).unwrap();
let mut signer = Signer::new(MessageDigest::sha256(), &pkey).unwrap();
signer.set_rsa_padding(Padding::PKCS1_PSS).unwrap();
signer
.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)
.unwrap();
signer.update(input).unwrap();
signer.sign(output).unwrap()
}

View File

@ -2,7 +2,10 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::thread::{self, JoinHandle};
use crate::ssl::{Ssl, SslContext, SslContextBuilder, SslFiletype, SslMethod, SslRef, SslStream};
use crate::ssl::{
HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype,
SslMethod, SslRef, SslStream,
};
pub struct Server {
handle: Option<JoinHandle<()>>,
@ -28,6 +31,7 @@ impl Server {
ctx,
ssl_cb: Box::new(|_| {}),
io_cb: Box::new(|_| {}),
err_cb: Box::new(|_| {}),
should_error: false,
}
}
@ -39,6 +43,14 @@ impl Server {
}
}
pub fn client_with_root_ca(&self) -> ClientBuilder {
let mut client = self.client();
client.ctx().set_ca_file("test/root-ca.pem").unwrap();
client
}
pub fn connect_tcp(&self) -> TcpStream {
TcpStream::connect(self.addr).unwrap()
}
@ -48,6 +60,7 @@ pub struct Builder {
ctx: SslContextBuilder,
ssl_cb: Box<dyn FnMut(&mut SslRef) + Send>,
io_cb: Box<dyn FnMut(SslStream<TcpStream>) + Send>,
err_cb: Box<dyn FnMut(HandshakeError<TcpStream>) + Send>,
should_error: bool,
}
@ -70,6 +83,12 @@ impl Builder {
self.io_cb = Box::new(cb);
}
pub fn err_cb(&mut self, cb: impl FnMut(HandshakeError<TcpStream>) + Send + 'static) {
self.should_error();
self.err_cb = Box::new(cb);
}
pub fn should_error(&mut self) {
self.should_error = true;
}
@ -80,6 +99,7 @@ impl Builder {
let addr = socket.local_addr().unwrap();
let mut ssl_cb = self.ssl_cb;
let mut io_cb = self.io_cb;
let mut err_cb = self.err_cb;
let should_error = self.should_error;
let handle = thread::spawn(move || {
@ -88,7 +108,7 @@ impl Builder {
ssl_cb(&mut ssl);
let r = ssl.accept(socket);
if should_error {
r.unwrap_err();
err_cb(r.unwrap_err());
} else {
let mut socket = r.unwrap();
socket.write_all(&[0]).unwrap();
@ -124,8 +144,8 @@ impl ClientBuilder {
self.build().builder().connect()
}
pub fn connect_err(self) {
self.build().builder().connect_err();
pub fn connect_err(self) -> HandshakeError<TcpStream> {
self.build().builder().connect_err()
}
}
@ -160,8 +180,9 @@ impl ClientSslBuilder {
s
}
pub fn connect_err(self) {
pub fn connect_err(self) -> HandshakeError<TcpStream> {
let socket = TcpStream::connect(self.addr).unwrap();
self.ssl.connect(socket).unwrap_err();
self.ssl.setup_connect(socket).handshake().unwrap_err()
}
}