diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index d19f6678..058231ca 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -145,6 +145,34 @@ fn get_verify_data_idx() -> c_int { } } +/// Creates a static index for the list of NPN protocols. +/// Registers a destructor for the data which will be called +/// when the context is freed. +#[cfg(feature = "npn")] +fn get_npn_protos_idx() -> c_int { + static mut NPN_PROTOS_IDX: c_int = -1; + static mut INIT: Once = ONCE_INIT; + + extern fn free_data_box(_parent: *mut c_void, ptr: *mut c_void, + _ad: *mut ffi::CRYPTO_EX_DATA, _idx: c_int, + _argl: c_long, _argp: *mut c_void) { + if !ptr.is_null() { + let _: Box> = unsafe { mem::transmute(ptr) }; + } + } + + unsafe { + INIT.call_once(|| { + let f: ffi::CRYPTO_EX_free = free_data_box; + let idx = ffi::SSL_CTX_get_ex_new_index(0, ptr::null(), None, + None, Some(f)); + assert!(idx >= 0); + NPN_PROTOS_IDX = idx; + }); + NPN_PROTOS_IDX + } +} + extern fn raw_verify(preverify_ok: c_int, x509_ctx: *mut ffi::X509_STORE_CTX) -> c_int { unsafe { @@ -342,6 +370,31 @@ impl SslContext { }; SslContextOptions::from_bits(ret).unwrap() } + + /// Set the protocols to be used during Next Protocol Negotiation (the protocols + /// supported by the application). + /// + /// This method needs the `npn` feature. + #[cfg(feature = "npn")] + pub fn set_npn_protocols(&mut self, protocols: &[&[u8]]) { + // Firstly, convert the list of protocols to a byte-array that can be passed to OpenSSL + // APIs -- a list of length-prefixed strings. + let mut npn_protocols = Vec::new(); + for protocol in protocols { + let len = protocol.len() as u8; + npn_protocols.push(len); + // If the length is greater than the max `u8`, this truncates the protocol name. + npn_protocols.extend(protocol[..len as usize].to_vec()); + } + let protocols: Box> = Box::new(npn_protocols); + + unsafe { + // Attach the protocol list to the OpenSSL context structure, + // so that we can refer to it within the callback. + ffi::SSL_CTX_set_ex_data(*self.ctx, get_npn_protos_idx(), + mem::transmute(protocols)); + } + } } #[allow(dead_code)]