diff --git a/openssl-sys/src/lib.rs b/openssl-sys/src/lib.rs index fda47fd0..e7bd046e 100644 --- a/openssl-sys/src/lib.rs +++ b/openssl-sys/src/lib.rs @@ -1453,6 +1453,10 @@ pub unsafe fn BIO_set_retry_write(b: *mut BIO) { BIO_set_flags(b, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY) } +pub unsafe fn EVP_get_digestbynid(type_: c_int) -> *const EVP_MD { + EVP_get_digestbyname(OBJ_nid2sn(type_)) +} + // EVP_PKEY_CTX_ctrl macros pub unsafe fn EVP_PKEY_CTX_set_rsa_padding(ctx: *mut EVP_PKEY_CTX, pad: c_int) -> c_int { EVP_PKEY_CTX_ctrl( @@ -2103,6 +2107,8 @@ extern "C" { no_name: c_int, ) -> c_int; pub fn OBJ_nid2sn(nid: c_int) -> *const c_char; + pub fn OBJ_find_sigid_algs(signid: c_int, pdig_nid: *mut c_int, ppkey_nid: *mut c_int) + -> c_int; pub fn OCSP_BASICRESP_new() -> *mut OCSP_BASICRESP; pub fn OCSP_BASICRESP_free(r: *mut OCSP_BASICRESP); @@ -2840,6 +2846,7 @@ extern "C" { ); pub fn EVP_MD_size(md: *const EVP_MD) -> c_int; + pub fn EVP_get_digestbyname(name: *const c_char) -> *const EVP_MD; pub fn EVP_get_cipherbyname(name: *const c_char) -> *const EVP_CIPHER; pub fn SSL_set_connect_state(s: *mut SSL); diff --git a/openssl/src/hash.rs b/openssl/src/hash.rs index 582a2ada..49797df8 100644 --- a/openssl/src/hash.rs +++ b/openssl/src/hash.rs @@ -4,6 +4,10 @@ use std::io; use std::io::prelude::*; use std::ops::{Deref, DerefMut}; +use error::ErrorStack; +use nid::Nid; +use {cvt, cvt_p}; + cfg_if! { if #[cfg(ossl110)] { use ffi::{EVP_MD_CTX_free, EVP_MD_CTX_new}; @@ -12,9 +16,6 @@ cfg_if! { } } -use error::ErrorStack; -use {cvt, cvt_p}; - #[derive(Copy, Clone)] pub struct MessageDigest(*const ffi::EVP_MD); @@ -23,6 +24,22 @@ impl MessageDigest { MessageDigest(x) } + /// Returns the `MessageDigest` corresponding to an `Nid`. + /// + /// This corresponds to [`EVP_get_digestbynid`]. + /// + /// [`EVP_get_digestbynid`]: https://www.openssl.org/docs/man1.1.0/crypto/EVP_DigestInit.html + pub fn from_nid(type_: Nid) -> Option { + unsafe { + let ptr = ffi::EVP_get_digestbynid(type_.as_raw()); + if ptr.is_null() { + None + } else { + Some(MessageDigest(ptr)) + } + } + } + pub fn md5() -> MessageDigest { unsafe { MessageDigest(ffi::EVP_md5()) } } @@ -405,4 +422,12 @@ mod tests { hash_test(MessageDigest::ripemd160(), test); } } + + #[test] + fn from_nid() { + assert_eq!( + MessageDigest::from_nid(Nid::SHA256).unwrap().as_ptr(), + MessageDigest::sha256().as_ptr() + ); + } } diff --git a/openssl/src/nid.rs b/openssl/src/nid.rs index 7c041236..78ffac96 100644 --- a/openssl/src/nid.rs +++ b/openssl/src/nid.rs @@ -1,6 +1,7 @@ //! A collection of numerical identifiers for OpenSSL objects. use ffi; use libc::c_int; +use std::ptr; /// A numerical identifier for an OpenSSL object. /// @@ -42,6 +43,20 @@ impl Nid { self.0 } + /// Returns the `Nid` of the digest algorithm associated with a signature ID. + /// + /// This corresponds to `OBJ_find_sigid_algs`. + pub fn digest_algorithm(&self) -> Option { + unsafe { + let mut digest = 0; + if ffi::OBJ_find_sigid_algs(self.0, &mut digest, ptr::null_mut()) == 1 { + Some(Nid(digest)) + } else { + None + } + } + } + pub const UNDEF: Nid = Nid(ffi::NID_undef); pub const ITU_T: Nid = Nid(ffi::NID_itu_t); pub const CCITT: Nid = Nid(ffi::NID_ccitt); @@ -991,3 +1006,16 @@ impl Nid { pub const AES_192_CBC_HMAC_SHA1: Nid = Nid(ffi::NID_aes_192_cbc_hmac_sha1); pub const AES_256_CBC_HMAC_SHA1: Nid = Nid(ffi::NID_aes_256_cbc_hmac_sha1); } + +#[cfg(test)] +mod test { + use super::Nid; + + #[test] + fn signature_digest() { + assert_eq!( + Nid::SHA256WITHRSAENCRYPTION.digest_algorithm(), + Some(Nid::SHA256) + ); + } +}