diff --git a/openssl-sys/src/evp.rs b/openssl-sys/src/evp.rs index b248f7f9..8a057a0d 100644 --- a/openssl-sys/src/evp.rs +++ b/openssl-sys/src/evp.rs @@ -403,3 +403,8 @@ extern "C" { pub fn EVP_PKEY_keygen_init(ctx: *mut EVP_PKEY_CTX) -> c_int; pub fn EVP_PKEY_keygen(ctx: *mut EVP_PKEY_CTX, key: *mut *mut EVP_PKEY) -> c_int; } + +extern "C" { + pub fn EVP_EncodeBlock(dst: *mut c_uchar, src: *const c_uchar, src_len: c_int) -> c_int; + pub fn EVP_DecodeBlock(dst: *mut c_uchar, src: *const c_uchar, src_len: c_int) -> c_int; +} diff --git a/openssl/src/base64.rs b/openssl/src/base64.rs new file mode 100644 index 00000000..42708c42 --- /dev/null +++ b/openssl/src/base64.rs @@ -0,0 +1,125 @@ +//! Utilities for base64 coding +//! +//! See manual page of [`EVP_EncodeInit`] for more information on the specific base64 variant. +//! +//! [`EVP_EncodeInit`]: https://www.openssl.org/docs/man1.1.1/man3/EVP_EncodeInit.html +use error::ErrorStack; +use ffi; +use libc::c_int; +use cvt_n; + +/// Encodes a given block of bytes to base64. +/// +/// # Panics +/// +/// Panics if the input length or computed output length +/// overflow a signed C integer. +pub fn encode_block(src: &[u8]) -> String { + assert!(src.len() <= c_int::max_value() as usize); + let src_len = src.len() as c_int; + + let len = encoded_len(src_len).unwrap(); + let mut out = Vec::new(); + out.reserve(len as usize); + + // SAFETY: `encoded_len` ensures space for 4 output characters + // for every 3 input bytes including padding and nul terminator. + // `EVP_EncodeBlock` will write only single byte ASCII characters. + // `EVP_EncodeBlock` will only write to not read from `out`. + unsafe { + let out_len = ffi::EVP_EncodeBlock(out.as_mut_ptr(), src.as_ptr(), src_len); + out.set_len(out_len as usize); + String::from_utf8_unchecked(out) + } +} + +/// Decodes a given base64-encoded text to bytes. +/// +/// # Panics +/// +/// Panics if the input length or computed output length +/// overflow a signed C integer. +pub fn decode_block(src: &str) -> Result, ErrorStack> { + let src = src.trim(); + + assert!(src.len() <= c_int::max_value() as usize); + let src_len = src.len() as c_int; + + let len = decoded_len(src_len).unwrap(); + let mut out = Vec::new(); + out.reserve(len as usize); + + // SAFETY: `decoded_len` ensures space for 3 output bytes + // for every 4 input characters including padding. + // `EVP_DecodeBlock` can write fewer bytes after stripping + // leading and trailing whitespace, but never more. + // `EVP_DecodeBlock` will only write to not read from `out`. + unsafe { + let out_len = cvt_n(ffi::EVP_DecodeBlock(out.as_mut_ptr(), src.as_ptr(), src_len))?; + out.set_len(out_len as usize); + } + + if src.ends_with("=") { + out.pop(); + if src.ends_with("==") { + out.pop(); + } + } + + Ok(out) +} + +fn encoded_len(src_len: c_int) -> Option { + let mut len = (src_len / 3).checked_mul(4)?; + + if src_len % 3 != 0 { + len = len.checked_add(4)?; + } + + len = len.checked_add(1)?; + + Some(len) +} + +fn decoded_len(src_len: c_int) -> Option { + let mut len = (src_len / 4).checked_mul(3)?; + + if src_len % 4 != 0 { + len = len.checked_add(3)?; + } + + Some(len) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encode_block() { + assert_eq!("".to_string(), encode_block(b"")); + assert_eq!("Zg==".to_string(), encode_block(b"f")); + assert_eq!("Zm8=".to_string(), encode_block(b"fo")); + assert_eq!("Zm9v".to_string(), encode_block(b"foo")); + assert_eq!("Zm9vYg==".to_string(), encode_block(b"foob")); + assert_eq!("Zm9vYmE=".to_string(), encode_block(b"fooba")); + assert_eq!("Zm9vYmFy".to_string(), encode_block(b"foobar")); + } + + #[test] + fn test_decode_block() { + assert_eq!(b"".to_vec(), decode_block("").unwrap()); + assert_eq!(b"f".to_vec(), decode_block("Zg==").unwrap()); + assert_eq!(b"fo".to_vec(), decode_block("Zm8=").unwrap()); + assert_eq!(b"foo".to_vec(), decode_block("Zm9v").unwrap()); + assert_eq!(b"foob".to_vec(), decode_block("Zm9vYg==").unwrap()); + assert_eq!(b"fooba".to_vec(), decode_block("Zm9vYmE=").unwrap()); + assert_eq!(b"foobar".to_vec(), decode_block("Zm9vYmFy").unwrap()); + } + + #[test] + fn test_strip_whitespace() { + assert_eq!(b"foobar".to_vec(), decode_block(" Zm9vYmFy\n").unwrap()); + assert_eq!(b"foob".to_vec(), decode_block(" Zm9vYg==\n").unwrap()); + } +} diff --git a/openssl/src/lib.rs b/openssl/src/lib.rs index e9ee2e21..59e65649 100644 --- a/openssl/src/lib.rs +++ b/openssl/src/lib.rs @@ -140,6 +140,7 @@ mod bio; mod util; pub mod aes; pub mod asn1; +pub mod base64; pub mod bn; #[cfg(not(libressl))] pub mod cms;