// Listing metadata: 284 lines, 7.3 KiB, Rust.
use once_cell::sync::OnceCell;
|
|
|
|
use super::server::{Builder, Server};
|
|
use super::KEY;
|
|
use crate::hash::{Hasher, MessageDigest};
|
|
use crate::pkey::PKey;
|
|
use crate::rsa::Padding;
|
|
use crate::sign::{RsaPssSaltlen, Signer};
|
|
use crate::ssl::{
|
|
ErrorCode, HandshakeError, PrivateKeyMethod, PrivateKeyMethodError, SslRef,
|
|
SslSignatureAlgorithm,
|
|
};
|
|
use crate::x509::X509;
|
|
use std::cmp;
|
|
use std::io::{Read, Write};
|
|
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
|
use std::sync::Arc;
|
|
|
|
#[allow(clippy::type_complexity)]
|
|
pub(super) struct Method {
|
|
sign: Box<
|
|
dyn Fn(
|
|
&mut SslRef,
|
|
&[u8],
|
|
SslSignatureAlgorithm,
|
|
&mut [u8],
|
|
) -> Result<usize, PrivateKeyMethodError>
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
>,
|
|
decrypt: Box<
|
|
dyn Fn(&mut SslRef, &[u8], &mut [u8]) -> Result<usize, PrivateKeyMethodError>
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
>,
|
|
complete: Box<
|
|
dyn Fn(&mut SslRef, &mut [u8]) -> Result<usize, PrivateKeyMethodError>
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
>,
|
|
}
|
|
|
|
impl Method {
|
|
pub(super) fn new() -> Self {
|
|
Self {
|
|
sign: Box::new(|_, _, _, _| unreachable!("called sign")),
|
|
decrypt: Box::new(|_, _, _| unreachable!("called decrypt")),
|
|
complete: Box::new(|_, _| unreachable!("called complete")),
|
|
}
|
|
}
|
|
|
|
pub(super) fn sign(
|
|
mut self,
|
|
sign: impl Fn(
|
|
&mut SslRef,
|
|
&[u8],
|
|
SslSignatureAlgorithm,
|
|
&mut [u8],
|
|
) -> Result<usize, PrivateKeyMethodError>
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
) -> Self {
|
|
self.sign = Box::new(sign);
|
|
|
|
self
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub(super) fn decrypt(
|
|
mut self,
|
|
decrypt: impl Fn(&mut SslRef, &[u8], &mut [u8]) -> Result<usize, PrivateKeyMethodError>
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
) -> Self {
|
|
self.decrypt = Box::new(decrypt);
|
|
|
|
self
|
|
}
|
|
|
|
pub(super) fn complete(
|
|
mut self,
|
|
complete: impl Fn(&mut SslRef, &mut [u8]) -> Result<usize, PrivateKeyMethodError>
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
) -> Self {
|
|
self.complete = Box::new(complete);
|
|
|
|
self
|
|
}
|
|
}
|
|
|
|
impl PrivateKeyMethod for Method {
|
|
fn sign(
|
|
&self,
|
|
ssl: &mut SslRef,
|
|
input: &[u8],
|
|
signature_algorithm: SslSignatureAlgorithm,
|
|
output: &mut [u8],
|
|
) -> Result<usize, PrivateKeyMethodError> {
|
|
(self.sign)(ssl, input, signature_algorithm, output)
|
|
}
|
|
|
|
fn decrypt(
|
|
&self,
|
|
ssl: &mut SslRef,
|
|
input: &[u8],
|
|
output: &mut [u8],
|
|
) -> Result<usize, PrivateKeyMethodError> {
|
|
(self.decrypt)(ssl, input, output)
|
|
}
|
|
|
|
fn complete(
|
|
&self,
|
|
ssl: &mut SslRef,
|
|
output: &mut [u8],
|
|
) -> Result<usize, PrivateKeyMethodError> {
|
|
(self.complete)(ssl, output)
|
|
}
|
|
}
|
|
|
|
fn builder_with_private_key_method(method: Method) -> Builder {
|
|
let mut builder = Server::builder();
|
|
|
|
builder.ctx().set_private_key_method(method);
|
|
|
|
builder
|
|
}
|
|
|
|
/// A `sign` callback that reports `FAILURE` must abort the server handshake with an
/// `ErrorCode::SSL` error, and the callback must actually have run.
#[test]
fn test_sign_failure() {
    let sign_was_called = Arc::new(AtomicBool::new(false));
    let sign_was_called_in_callback = Arc::clone(&sign_was_called);

    let method = Method::new().sign(move |_, _, _, _| {
        sign_was_called_in_callback.store(true, Ordering::SeqCst);
        Err(PrivateKeyMethodError::FAILURE)
    });
    let mut builder = builder_with_private_key_method(method);

    builder.err_cb(|error| match error {
        HandshakeError::Failure(mid_handshake) => {
            assert_eq!(mid_handshake.error().code(), ErrorCode::SSL)
        }
        _ => panic!("should be Failure"),
    });

    let server = builder.build();
    let client = server.client_with_root_ca();
    client.connect_err();

    assert!(sign_was_called.load(Ordering::SeqCst));
}
|
|
|
|
/// `sign` defers with `RETRY`. `complete` then asks for one more retry and fails on the
/// following attempt. The server must therefore see two `WouldBlock` results (both
/// `WANT_PRIVATE_KEY_OPERATION`) followed by a fatal `ErrorCode::SSL` failure, with
/// `complete` called exactly twice.
#[test]
fn test_sign_retry_complete_failure() {
    let complete_calls = Arc::new(AtomicUsize::new(0));
    let complete_calls_in_callback = Arc::clone(&complete_calls);

    let method = Method::new()
        .sign(|_, _, _, _| Err(PrivateKeyMethodError::RETRY))
        .complete(move |_, _| {
            let calls_before = complete_calls_in_callback.fetch_add(1, Ordering::SeqCst);
            // Request another retry on the first attempt; fail every attempt after that.
            let outcome = if calls_before == 0 {
                PrivateKeyMethodError::RETRY
            } else {
                PrivateKeyMethodError::FAILURE
            };
            Err(outcome)
        });
    let mut builder = builder_with_private_key_method(method);

    builder.err_cb(|error| {
        // Initial handshake: `sign` returned RETRY.
        let mid_handshake = match error {
            HandshakeError::WouldBlock(mid_handshake) => mid_handshake,
            _ => panic!("should be WouldBlock"),
        };
        assert!(mid_handshake.error().would_block());
        assert_eq!(
            mid_handshake.error().code(),
            ErrorCode::WANT_PRIVATE_KEY_OPERATION
        );

        // First resume: `complete` returned RETRY.
        let mid_handshake = match mid_handshake.handshake().unwrap_err() {
            HandshakeError::WouldBlock(mid_handshake) => mid_handshake,
            _ => panic!("should be WouldBlock"),
        };
        assert_eq!(
            mid_handshake.error().code(),
            ErrorCode::WANT_PRIVATE_KEY_OPERATION
        );

        // Second resume: `complete` returned FAILURE.
        let mid_handshake = match mid_handshake.handshake().unwrap_err() {
            HandshakeError::Failure(mid_handshake) => mid_handshake,
            _ => panic!("should be Failure"),
        };
        assert_eq!(mid_handshake.error().code(), ErrorCode::SSL);
    });

    let server = builder.build();
    let client = server.client_with_root_ca();
    client.connect_err();

    assert_eq!(complete_calls.load(Ordering::SeqCst), 2);
}
|
|
|
|
/// A synchronous `sign` callback that produces a real signature lets the handshake
/// complete. The callback also checks that the requested algorithm is
/// `RSA_PSS_RSAE_SHA256`.
#[test]
fn test_sign_ok() {
    let method = Method::new().sign(|_, input, signature_algorithm, output| {
        assert_eq!(
            signature_algorithm,
            SslSignatureAlgorithm::RSA_PSS_RSAE_SHA256,
        );
        let written = sign_with_default_config(input, output);
        Ok(written)
    });

    let server = builder_with_private_key_method(method).build();
    let client = server.client_with_root_ca();
    client.connect();
}
|
|
|
|
/// `sign` stashes its input and defers with `RETRY`. When the server resumes the
/// handshake, `complete` signs the stashed input, the handshake succeeds, and the server
/// writes a single byte on the resulting stream.
#[test]
fn test_sign_retry_complete_ok() {
    let stashed_input = Arc::new(OnceCell::new());
    let stashed_input_for_complete = Arc::clone(&stashed_input);

    let method = Method::new()
        .sign(move |_, input, _, _| {
            // Panics if `sign` were ever invoked a second time.
            stashed_input.set(input.to_owned()).unwrap();
            Err(PrivateKeyMethodError::RETRY)
        })
        .complete(move |_, output| {
            let input = stashed_input_for_complete.get().unwrap();
            Ok(sign_with_default_config(input, output))
        });
    let mut builder = builder_with_private_key_method(method);

    builder.err_cb(|error| match error {
        HandshakeError::WouldBlock(mid_handshake) => {
            let mut stream = mid_handshake.handshake().unwrap();
            stream.write_all(&[0]).unwrap();
        }
        _ => panic!("should be WouldBlock"),
    });

    let server = builder.build();
    let client = server.client_with_root_ca();
    client.connect();
}
|
|
|
|
fn sign_with_default_config(input: &[u8], output: &mut [u8]) -> usize {
|
|
let pkey = PKey::private_key_from_pem(KEY).unwrap();
|
|
let mut signer = Signer::new(MessageDigest::sha256(), &pkey).unwrap();
|
|
|
|
signer.set_rsa_padding(Padding::PKCS1_PSS).unwrap();
|
|
signer
|
|
.set_rsa_pss_saltlen(RsaPssSaltlen::DIGEST_LENGTH)
|
|
.unwrap();
|
|
|
|
signer.update(input).unwrap();
|
|
|
|
signer.sign(output).unwrap()
|
|
}
|