From 563754fb0892ebf8021bb6043f4540c98f3b86a6 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 12 Nov 2016 12:43:44 +0000 Subject: [PATCH] Add SslContextBuilder::set_tmp_{ec,}dh_callback --- openssl-sys/src/lib.rs | 10 ++++ openssl-sys/src/ossl10x.rs | 10 ++++ openssl/src/dsa.rs | 1 + openssl/src/rsa.rs | 1 + openssl/src/ssl/mod.rs | 91 +++++++++++++++++++++++++++++++++--- openssl/src/ssl/tests/mod.rs | 65 ++++++++++++++++++++++++++ 6 files changed, 172 insertions(+), 6 deletions(-) diff --git a/openssl-sys/src/lib.rs b/openssl-sys/src/lib.rs index 17c2864d..7fda34d4 100644 --- a/openssl-sys/src/lib.rs +++ b/openssl-sys/src/lib.rs @@ -1592,6 +1592,11 @@ extern { #[cfg(not(ossl101))] pub fn SSL_get_privatekey(ssl: *const SSL) -> *mut EVP_PKEY; pub fn SSL_load_client_CA_file(file: *const c_char) -> *mut stack_st_X509_NAME; + pub fn SSL_set_tmp_dh_callback(ctx: *mut SSL, + dh: unsafe extern fn(ssl: *mut SSL, + is_export: c_int, + keylength: c_int) + -> *mut DH); #[cfg(not(osslconf = "OPENSSL_NO_COMP"))] pub fn SSL_COMP_get_name(comp: *const COMP_METHOD) -> *const c_char; @@ -1624,6 +1629,11 @@ extern { pub fn SSL_CTX_check_private_key(ctx: *const SSL_CTX) -> c_int; pub fn SSL_CTX_set_client_CA_list(ctx: *mut SSL_CTX, list: *mut stack_st_X509_NAME); pub fn SSL_CTX_get_cert_store(ctx: *const SSL_CTX) -> *mut X509_STORE; + pub fn SSL_CTX_set_tmp_dh_callback(ctx: *mut SSL_CTX, + dh: unsafe extern fn(ssl: *mut SSL, + is_export: c_int, + keylength: c_int) + -> *mut DH); #[cfg(not(ossl101))] pub fn SSL_CTX_get0_certificate(ctx: *const SSL_CTX) -> *mut X509; diff --git a/openssl-sys/src/ossl10x.rs b/openssl-sys/src/ossl10x.rs index baab16ac..ad7d065c 100644 --- a/openssl-sys/src/ossl10x.rs +++ b/openssl-sys/src/ossl10x.rs @@ -576,12 +576,22 @@ extern { dup_func: Option<::CRYPTO_EX_dup>, free_func: Option<::CRYPTO_EX_free>) -> c_int; + pub fn SSL_set_tmp_ecdh_callback(ssl: *mut ::SSL, + ecdh: unsafe extern fn(ssl: *mut ::SSL, + is_export: c_int, + keylength: c_int) + -> *mut ::EC_KEY); pub fn SSL_CIPHER_get_version(cipher: *const ::SSL_CIPHER) -> *mut c_char; pub fn SSL_CTX_get_ex_new_index(argl: c_long, argp: *mut c_void, new_func: Option<::CRYPTO_EX_new>, dup_func: Option<::CRYPTO_EX_dup>, free_func: Option<::CRYPTO_EX_free>) -> c_int; + pub fn SSL_CTX_set_tmp_ecdh_callback(ctx: *mut ::SSL_CTX, + ecdh: unsafe extern fn(ssl: *mut ::SSL, + is_export: c_int, + keylength: c_int) + -> *mut ::EC_KEY); pub fn X509_get_subject_name(x: *mut ::X509) -> *mut ::X509_NAME; pub fn X509_set_notAfter(x: *mut ::X509, tm: *const ::ASN1_TIME) -> c_int; pub fn X509_set_notBefore(x: *mut ::X509, tm: *const ::ASN1_TIME) -> c_int; diff --git a/openssl/src/dsa.rs b/openssl/src/dsa.rs index 81e551ab..72076775 100644 --- a/openssl/src/dsa.rs +++ b/openssl/src/dsa.rs @@ -57,6 +57,7 @@ impl DsaRef { } } + // FIXME should return u32 pub fn size(&self) -> Option { if self.q().is_some() { unsafe { Some(ffi::DSA_size(self.as_ptr()) as u32) } diff --git a/openssl/src/rsa.rs b/openssl/src/rsa.rs index 70d41997..bd36570d 100644 --- a/openssl/src/rsa.rs +++ b/openssl/src/rsa.rs @@ -70,6 +70,7 @@ impl RsaRef { } } + // FIXME should return u32 pub fn size(&self) -> usize { unsafe { assert!(self.n().is_some()); diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 1e0d2e66..ed3f023c 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -91,8 +91,10 @@ use std::str; use std::sync::Mutex; use {init, cvt, cvt_p}; -use dh::DhRef; +use dh::{Dh, DhRef}; use ec_key::EcKeyRef; +#[cfg(any(all(feature = "v101", ossl101), all(feature = "v102", ossl102)))] +use ec_key::EcKey; use x509::{X509StoreContextRef, X509FileType, X509, X509Ref, X509VerifyError, X509Name}; use x509::store::X509StoreBuilderRef; #[cfg(any(ossl102, ossl110))] @@ -219,7 +221,7 @@ lazy_static! { // Creates a static index for user data of type T // Registers a destructor for the data which will be called // when context is freed -fn get_verify_data_idx() -> c_int { +fn get_callback_idx() -> c_int { *INDEXES.lock().unwrap().entry(TypeId::of::()).or_insert_with(|| get_new_idx::()) } @@ -272,7 +274,7 @@ extern "C" fn raw_verify(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_ let idx = ffi::SSL_get_ex_data_X509_STORE_CTX_idx(); let ssl = ffi::X509_STORE_CTX_get_ex_data(x509_ctx, idx); let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl as *const _); - let verify = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_verify_data_idx::()); + let verify = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_callback_idx::()); let verify: &F = &*(verify as *mut F); let ctx = X509StoreContextRef::from_ptr(x509_ctx); @@ -301,7 +303,7 @@ extern "C" fn raw_sni(ssl: *mut ffi::SSL, al: *mut c_int, _arg: *mut c_void) { unsafe { let ssl_ctx = ffi::SSL_get_SSL_CTX(ssl); - let callback = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_verify_data_idx::()); + let callback = ffi::SSL_CTX_get_ex_data(ssl_ctx, get_callback_idx::()); let callback: &F = &*(callback as *mut F); let ssl = SslRef::from_ptr_mut(ssl); @@ -373,6 +375,55 @@ extern "C" fn raw_alpn_select_cb(ssl: *mut ffi::SSL, unsafe { select_proto_using(ssl, out as *mut _, outlen, inbuf, inlen, *ALPN_PROTOS_IDX) } } +unsafe extern fn raw_tmp_dh(ssl: *mut ffi::SSL, + is_export: c_int, + keylength: c_int) + -> *mut ffi::DH + where F: Fn(&mut SslRef, bool, u32) -> Result + Any + 'static + Sync + Send +{ + let ctx = ffi::SSL_get_SSL_CTX(ssl); + let callback = ffi::SSL_CTX_get_ex_data(ctx, get_callback_idx::()); + let callback = &*(callback as *mut F); + + let ssl = SslRef::from_ptr_mut(ssl); + match callback(ssl, is_export != 0, keylength as u32) { + Ok(dh) => { + let ptr = dh.as_ptr(); + mem::forget(dh); + ptr + } + Err(_) => { + // FIXME reset error stack + ptr::null_mut() + } + } +} + +#[cfg(any(all(feature = "v101", ossl101), all(feature = "v102", ossl102)))] +unsafe extern fn raw_tmp_ecdh(ssl: *mut ffi::SSL, + is_export: c_int, + keylength: c_int) + -> *mut ffi::EC_KEY + where F: Fn(&mut SslRef, bool, u32) -> Result + Any + 'static + Sync + Send +{ + let ctx = ffi::SSL_get_SSL_CTX(ssl); + let callback = ffi::SSL_CTX_get_ex_data(ctx, get_callback_idx::()); + let callback = &*(callback as *mut F); + + let ssl = SslRef::from_ptr_mut(ssl); + match callback(ssl, is_export != 0, keylength as u32) { + Ok(ec_key) => { + let ptr = ec_key.as_ptr(); + mem::forget(ec_key); + ptr + } + Err(_) => { + // FIXME reset error stack + ptr::null_mut() + } + } +} + /// The function is given as the callback to `SSL_CTX_set_next_protos_advertised_cb`. /// /// It causes the parameter `out` to point at a `*const c_uchar` instance that @@ -472,7 +523,7 @@ impl SslContextBuilder { unsafe { let verify = Box::new(verify); ffi::SSL_CTX_set_ex_data(self.as_ptr(), - get_verify_data_idx::(), + get_callback_idx::(), mem::transmute(verify)); ffi::SSL_CTX_set_verify(self.as_ptr(), mode.bits as c_int, Some(raw_verify::)); } @@ -488,7 +539,7 @@ impl SslContextBuilder { unsafe { let callback = Box::new(callback); ffi::SSL_CTX_set_ex_data(self.as_ptr(), - get_verify_data_idx::(), + get_callback_idx::(), mem::transmute(callback)); let f: extern "C" fn(_, _, _) -> _ = raw_sni::; let f: extern "C" fn() = mem::transmute(f); @@ -520,10 +571,38 @@ impl SslContextBuilder { unsafe { cvt(ffi::SSL_CTX_set_tmp_dh(self.as_ptr(), dh.as_ptr()) as c_int).map(|_| ()) } } + pub fn set_tmp_dh_callback(&mut self, callback: F) + where F: Fn(&mut SslRef, bool, u32) -> Result + Any + 'static + Sync + Send + { + unsafe { + let callback = Box::new(callback); + ffi::SSL_CTX_set_ex_data(self.as_ptr(), + get_callback_idx::(), + Box::into_raw(callback) as *mut c_void); + let f: unsafe extern fn (_, _, _) -> _ = raw_tmp_dh::; + ffi::SSL_CTX_set_tmp_dh_callback(self.as_ptr(), f); + } + } + pub fn set_tmp_ecdh(&mut self, key: &EcKeyRef) -> Result<(), ErrorStack> { unsafe { cvt(ffi::SSL_CTX_set_tmp_ecdh(self.as_ptr(), key.as_ptr()) as c_int).map(|_| ()) } } + /// Requires the `v101` feature and OpenSSL 1.0.1, or the `v102` feature and OpenSSL 1.0.2. + #[cfg(any(all(feature = "v101", ossl101), all(feature = "v102", ossl102)))] + pub fn set_tmp_ecdh_callback(&mut self, callback: F) + where F: Fn(&mut SslRef, bool, u32) -> Result + Any + 'static + Sync + Send + { + unsafe { + let callback = Box::new(callback); + ffi::SSL_CTX_set_ex_data(self.as_ptr(), + get_callback_idx::(), + Box::into_raw(callback) as *mut c_void); + let f: unsafe extern fn(_, _, _) -> _ = raw_tmp_ecdh::; + ffi::SSL_CTX_set_tmp_ecdh_callback(self.as_ptr(), f); + } + } + /// Use the default locations of trusted certificates for verification. /// /// These locations are read from the `SSL_CERT_FILE` and `SSL_CERT_DIR` diff --git a/openssl/src/ssl/tests/mod.rs b/openssl/src/ssl/tests/mod.rs index fa7c6024..d79e5386 100644 --- a/openssl/src/ssl/tests/mod.rs +++ b/openssl/src/ssl/tests/mod.rs @@ -9,10 +9,12 @@ use std::mem; use std::net::{TcpStream, TcpListener, SocketAddr}; use std::path::Path; use std::process::{Command, Child, Stdio, ChildStdin}; +use std::sync::atomic::{AtomicBool, ATOMIC_BOOL_INIT, Ordering}; use std::thread; use std::time::Duration; use tempdir::TempDir; +use dh::Dh; use hash::MessageDigest; use ssl; use ssl::SSL_VERIFY_PEER; @@ -1206,6 +1208,69 @@ fn cert_store() { ctx.connect("foobar.com", tcp).unwrap(); } +#[test] +fn tmp_dh_callback() { + static CALLED_BACK: AtomicBool = ATOMIC_BOOL_INIT; + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + + thread::spawn(move ||{ + let stream = listener.accept().unwrap().0; + let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); + ctx.set_certificate_file(&Path::new("test/cert.pem"), X509_FILETYPE_PEM).unwrap(); + ctx.set_private_key_file(&Path::new("test/key.pem"), X509_FILETYPE_PEM).unwrap(); + ctx.set_tmp_dh_callback(|_, _, _| { + CALLED_BACK.store(true, Ordering::SeqCst); + let dh = include_bytes!("../../../test/dhparams.pem"); + Dh::from_pem(dh) + }); + let ssl = Ssl::new(&ctx.build()).unwrap(); + ssl.accept(stream).unwrap(); + }); + + let stream = TcpStream::connect(("127.0.0.1", port)).unwrap(); + let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); + ctx.set_cipher_list("DHE").unwrap(); + let ssl = Ssl::new(&ctx.build()).unwrap(); + ssl.connect(stream).unwrap(); + + assert!(CALLED_BACK.load(Ordering::SeqCst)); +} + +#[test] +#[cfg(any(all(feature = "v101", ossl101), all(feature = "v102", ossl102)))] +fn tmp_ecdh_callback() { + use ec_key::EcKey; + use nid; + + static CALLED_BACK: AtomicBool = ATOMIC_BOOL_INIT; + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + + thread::spawn(move ||{ + let stream = listener.accept().unwrap().0; + let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); + ctx.set_certificate_file(&Path::new("test/cert.pem"), X509_FILETYPE_PEM).unwrap(); + ctx.set_private_key_file(&Path::new("test/key.pem"), X509_FILETYPE_PEM).unwrap(); + ctx.set_tmp_ecdh_callback(|_, _, _| { + CALLED_BACK.store(true, Ordering::SeqCst); + EcKey::new_by_curve_name(nid::X9_62_PRIME256V1) + }); + let ssl = Ssl::new(&ctx.build()).unwrap(); + ssl.accept(stream).unwrap(); + }); + + let stream = TcpStream::connect(("127.0.0.1", port)).unwrap(); + let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); + ctx.set_cipher_list("ECDHE").unwrap(); + let ssl = Ssl::new(&ctx.build()).unwrap(); + ssl.connect(stream).unwrap(); + + assert!(CALLED_BACK.load(Ordering::SeqCst)); +} + fn _check_kinds() { fn is_send() {} fn is_sync() {}