boring2/boring/src/ssl/test/private_key_method.rs

282 lines
7.2 KiB
Rust

use once_cell::sync::OnceCell;
use super::server::{Builder, Server};
use super::KEY;
use crate::hash::MessageDigest;
use crate::pkey::PKey;
use crate::rsa::Padding;
use crate::sign::{RsaPssSaltlen, Signer};
use crate::ssl::{
ErrorCode, HandshakeError, PrivateKeyMethod, PrivateKeyMethodError, SslRef,
SslSignatureAlgorithm,
};
use std::io::Write;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
#[allow(clippy::type_complexity)]
pub(super) struct Method {
sign: Box<
dyn Fn(
&mut SslRef,
&[u8],
SslSignatureAlgorithm,
&mut [u8],
) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
>,
decrypt: Box<
dyn Fn(&mut SslRef, &[u8], &mut [u8]) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
>,
complete: Box<
dyn Fn(&mut SslRef, &mut [u8]) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
>,
}
impl Method {
pub(super) fn new() -> Self {
Self {
sign: Box::new(|_, _, _, _| unreachable!("called sign")),
decrypt: Box::new(|_, _, _| unreachable!("called decrypt")),
complete: Box::new(|_, _| unreachable!("called complete")),
}
}
pub(super) fn sign(
mut self,
sign: impl Fn(
&mut SslRef,
&[u8],
SslSignatureAlgorithm,
&mut [u8],
) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
) -> Self {
self.sign = Box::new(sign);
self
}
#[allow(dead_code)]
pub(super) fn decrypt(
mut self,
decrypt: impl Fn(&mut SslRef, &[u8], &mut [u8]) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
) -> Self {
self.decrypt = Box::new(decrypt);
self
}
pub(super) fn complete(
mut self,
complete: impl Fn(&mut SslRef, &mut [u8]) -> Result<usize, PrivateKeyMethodError>
+ Send
+ Sync
+ 'static,
) -> Self {
self.complete = Box::new(complete);
self
}
}
impl PrivateKeyMethod for Method {
fn sign(
&self,
ssl: &mut SslRef,
input: &[u8],
signature_algorithm: SslSignatureAlgorithm,
output: &mut [u8],
) -> Result<usize, PrivateKeyMethodError> {
(self.sign)(ssl, input, signature_algorithm, output)
}
fn decrypt(
&self,
ssl: &mut SslRef,
input: &[u8],
output: &mut [u8],
) -> Result<usize, PrivateKeyMethodError> {
(self.decrypt)(ssl, input, output)
}
fn complete(
&self,
ssl: &mut SslRef,
output: &mut [u8],
) -> Result<usize, PrivateKeyMethodError> {
(self.complete)(ssl, output)
}
}
fn builder_with_private_key_method(method: Method) -> Builder {
let mut builder = Server::builder();
builder.ctx().set_private_key_method(method);
builder
}
#[test]
fn test_sign_failure() {
let called_sign = Arc::new(AtomicBool::new(false));
let called_sign_clone = called_sign.clone();
let mut builder = builder_with_private_key_method(Method::new().sign(move |_, _, _, _| {
called_sign_clone.store(true, Ordering::SeqCst);
Err(PrivateKeyMethodError::FAILURE)
}));
builder.err_cb(|error| {
let HandshakeError::Failure(mid_handshake) = error else {
panic!("should be Failure");
};
assert_eq!(mid_handshake.error().code(), ErrorCode::SSL);
});
let server = builder.build();
let client = server.client_with_root_ca();
client.connect_err();
assert!(called_sign.load(Ordering::SeqCst));
}
#[test]
fn test_sign_retry_complete_failure() {
let called_complete = Arc::new(AtomicUsize::new(0));
let called_complete_clone = called_complete.clone();
let mut builder = builder_with_private_key_method(
Method::new()
.sign(|_, _, _, _| Err(PrivateKeyMethodError::RETRY))
.complete(move |_, _| {
let old = called_complete_clone.fetch_add(1, Ordering::SeqCst);
Err(if old == 0 {
PrivateKeyMethodError::RETRY
} else {
PrivateKeyMethodError::FAILURE
})
}),
);
builder.err_cb(|error| {
let HandshakeError::WouldBlock(mid_handshake) = error else {
panic!("should be WouldBlock");
};
assert!(mid_handshake.error().would_block());
assert_eq!(
mid_handshake.error().code(),
ErrorCode::WANT_PRIVATE_KEY_OPERATION
);
let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err()
else {
panic!("should be WouldBlock");
};
assert_eq!(
mid_handshake.error().code(),
ErrorCode::WANT_PRIVATE_KEY_OPERATION
);
let HandshakeError::Failure(mid_handshake) = mid_handshake.handshake().unwrap_err() else {
panic!("should be Failure");
};
assert_eq!(mid_handshake.error().code(), ErrorCode::SSL);
});
let server = builder.build();
let client = server.client_with_root_ca();
client.connect_err();
assert_eq!(called_complete.load(Ordering::SeqCst), 2);
}
#[test]
fn test_sign_ok() {
let server = builder_with_private_key_method(Method::new().sign(
|_, input, signature_algorithm, output| {
assert_eq!(
signature_algorithm,
SslSignatureAlgorithm::RSA_PSS_RSAE_SHA256,
);
Ok(sign_with_default_config(input, output))
},
))
.build();
let client = server.client_with_root_ca();
client.connect();
}
#[test]
fn test_sign_retry_complete_ok() {
let input_cell = Arc::new(OnceCell::new());
let input_cell_clone = input_cell.clone();
let mut builder = builder_with_private_key_method(
Method::new()
.sign(move |_, input, _, _| {
input_cell.set(input.to_owned()).unwrap();
Err(PrivateKeyMethodError::RETRY)
})
.complete(move |_, output| {
let input = input_cell_clone.get().unwrap();
Ok(sign_with_default_config(input, output))
}),
);
builder.err_cb(|error| {
let HandshakeError::WouldBlock(mid_handshake) = error else {
panic!("should be WouldBlock");
};
let mut socket = mid_handshake.handshake().unwrap();
socket.write_all(&[0]).unwrap();
});
let server = builder.build();
let client = server.client_with_root_ca();
client.connect();
}
fn sign_with_default_config(input: &[u8], output: &mut [u8]) -> usize {
let pkey = PKey::private_key_from_pem(KEY).unwrap();
let mut signer = Signer::new(MessageDigest::sha256(), &pkey).unwrap();
signer.set_rsa_padding(Padding::PKCS1_PSS).unwrap();
signer
.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)
.unwrap();
signer.update(input).unwrap();
signer.sign(output).unwrap()
}