From 5408b641ddbddd9f40ec203901dd7cb1a7afa3c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20Sch=C3=B6lling?= Date: Wed, 4 Mar 2015 22:32:16 +0100 Subject: [PATCH] Add connect() support for UDP sockets --- openssl-sys/Cargo.toml | 1 + openssl-sys/src/lib.rs | 2 + openssl/Cargo.toml | 3 +- openssl/src/ssl/connected_socket.rs | 301 ++++++++++++++++++++++++++++ openssl/src/ssl/mod.rs | 8 +- openssl/src/ssl/tests.rs | 70 +++++-- 6 files changed, 365 insertions(+), 20 deletions(-) create mode 100644 openssl/src/ssl/connected_socket.rs diff --git a/openssl-sys/Cargo.toml b/openssl-sys/Cargo.toml index ffea6211..7b06dfed 100644 --- a/openssl-sys/Cargo.toml +++ b/openssl-sys/Cargo.toml @@ -14,6 +14,7 @@ build = "build.rs" [features] tlsv1_2 = [] tlsv1_1 = [] +dtlsv1 = [] sslv2 = [] aes_xts = [] npn = [] diff --git a/openssl-sys/src/lib.rs b/openssl-sys/src/lib.rs index a4accc29..c782f816 100644 --- a/openssl-sys/src/lib.rs +++ b/openssl-sys/src/lib.rs @@ -482,6 +482,8 @@ extern "C" { pub fn TLSv1_1_method() -> *const SSL_METHOD; #[cfg(feature = "tlsv1_2")] pub fn TLSv1_2_method() -> *const SSL_METHOD; + #[cfg(feature = "dtlsv1")] + pub fn DTLSv1_method() -> *const SSL_METHOD; pub fn SSLv23_method() -> *const SSL_METHOD; pub fn SSL_new(ctx: *mut SSL_CTX) -> *mut SSL; diff --git a/openssl/Cargo.toml b/openssl/Cargo.toml index 691b3d18..ca3bd417 100644 --- a/openssl/Cargo.toml +++ b/openssl/Cargo.toml @@ -7,11 +7,12 @@ description = "OpenSSL bindings" repository = "https://github.com/sfackler/rust-openssl" documentation = "https://sfackler.github.io/rust-openssl/doc/openssl" readme = "../README.md" -keywords = ["crypto", "tls", "ssl"] +keywords = ["crypto", "tls", "ssl", "dtls"] [features] tlsv1_2 = ["openssl-sys/tlsv1_2"] tlsv1_1 = ["openssl-sys/tlsv1_1"] +dtlsv1 = ["openssl-sys/dtlsv1"] sslv2 = ["openssl-sys/sslv2"] aes_xts = ["openssl-sys/aes_xts"] npn = ["openssl-sys/npn"] diff --git a/openssl/src/ssl/connected_socket.rs b/openssl/src/ssl/connected_socket.rs new file mode 100644 index 00000000..1ae5fc8d --- /dev/null +++ b/openssl/src/ssl/connected_socket.rs @@ -0,0 +1,301 @@ +use libc::funcs::bsd43::connect; +use std::os; +use std::os::unix::AsRawFd; +use std::os::unix::Fd; +use std::net::UdpSocket; +use std::net::ToSocketAddrs; +use std::net::SocketAddr; +use std::io::Error; +use std::io::ErrorKind; +use std::io::Read; +use std::io::Write; +use std::mem; +use std::time::duration::Duration; +use libc::types::os::common::bsd44::socklen_t; +use libc::types::os::common::bsd44::sockaddr_in; +use libc::types::os::common::bsd44::sockaddr_in6; +use libc::types::os::common::bsd44::in_addr; +use libc::types::os::common::bsd44::in6_addr; +use libc::types::os::common::posix01::timeval; +use libc::funcs::bsd43::setsockopt; +use libc::consts::os::bsd44::SOL_SOCKET; +use libc::consts::os::bsd44::AF_INET; +use libc::consts::os::bsd44::AF_INET6; +use libc::consts::os::posix88::EAGAIN; +use std::net::IpAddr; +use libc::types::os::arch::c95::c_int; +use libc::types::os::arch::c95::c_char; +use libc::types::common::c95::c_void; +use libc::funcs::bsd43::send; +use libc::funcs::bsd43::recv; +use std::num::Int; +use std::os::errno; +use std::ffi::CString; + +const SO_RCVTIMEO:c_int = 20; + +extern { + fn inet_pton(family: c_int, src: *const c_char, dst: *mut c_void) -> c_int; +} + +pub struct ConnectedSocket { + sock: S +} + +impl AsRawFd for ConnectedSocket { + fn as_raw_fd(&self) -> Fd { + self.sock.as_raw_fd() + } +} + +enum SockaddrIn { + V4(sockaddr_in), + V6(sockaddr_in6), +} + +trait IntoSockaddrIn { + fn into_sockaddr_in(self) -> Result; +} + +impl IntoSockaddrIn for SocketAddr { + fn into_sockaddr_in(self) -> Result { + let ip = format!("{}", self.ip()); + + match self.ip() { + IpAddr::V4(_) => { + let mut addr = sockaddr_in { + sin_zero: [0; 8], + sin_family: AF_INET as u16, + sin_port: Int::to_be(self.port()), + sin_addr: in_addr { + s_addr: 0 + } + }; + let cstr = CString::new(ip.clone()).unwrap(); + let res = unsafe { + inet_pton(addr.sin_family as c_int, + cstr.as_ptr() as *const i8, + mem::transmute(&mut addr.sin_addr)) + }; + + if res == 1 { + Ok(SockaddrIn::V4(addr)) + } else { + warn!("inet_pton() failed for IPv4: {}", ip); + Err(Error::new(ErrorKind::Other, + "calling inet_pton() for ipv4", None)) + } + }, + + IpAddr::V6(_) => { + let mut addr = sockaddr_in6 { + sin6_family: AF_INET6 as u16, + sin6_port: Int::to_be(self.port()), + sin6_flowinfo: 0, + sin6_scope_id: 0, + sin6_addr: in6_addr { + s6_addr: [0; 8], + } + }; + let cstr = CString::new(ip.clone()).unwrap(); + let res = unsafe { + inet_pton(addr.sin6_family as c_int, + cstr.as_ptr() as *const i8, + mem::transmute(&mut addr.sin6_addr)) + }; + + if res > 0 { + Ok(SockaddrIn::V6(addr)) + } else { + Err(Error::new(ErrorKind::Other, + "calling inet_pton() for ipv6", None)) + } + } + } + } +} + +pub trait Connect { + fn connect(self, addr: &A) -> Result,Error>; +} + +impl Connect for UdpSocket { + fn connect(self, address: &A) -> Result,Error> { + let fd = self.as_raw_fd(); + + let addr = try!(address.to_socket_addrs()).next(); + if addr.is_none() { + return Err(Error::new(ErrorKind::InvalidInput, + "no addresses to connect to", None)); + } + + let saddr = try!(addr.unwrap().into_sockaddr_in()); + + let res = match saddr { + SockaddrIn::V4(s) => unsafe { + let len = mem::size_of_val(&s) as socklen_t; + let addrp = Box::new(s); + connect(fd, mem::transmute(&*addrp), len) + }, + SockaddrIn::V6(s) => unsafe { + let len = mem::size_of_val(&s) as socklen_t; + let addrp = Box::new(s); + connect(fd, mem::transmute(&*addrp), len) + }, + }; + + if res == 0 { + Ok(ConnectedSocket { sock: self }) + } else { + Err(Error::new(ErrorKind::Other, + "error calling connect()", None)) + } + } +} + +impl Read for ConnectedSocket { + fn read(&mut self, buf: &mut [u8]) -> Result { + let flags = 0; + let ptr = buf.as_mut_ptr() as *mut c_void; + + debug!("recv'ing..."); + let len = unsafe { + recv(self.as_raw_fd(), ptr, buf.len() as u64, flags) + }; + + debug!("recv'ed len={:?}", len); + match len { + -1 => { + match errno() { + EAGAIN => Err(Error::new(ErrorKind::Interrupted, "EAGAIN", None)), + _ => Err(Error::new(ErrorKind::Other, + "recv() returned -1", None)), + } + }, + 0 => Err(Error::new(ErrorKind::Other, + "connection is closed", None)), + _ => Ok(len as usize), + } + } +} + +impl Write for ConnectedSocket { + fn write(&mut self, buf: &[u8]) -> Result { + let flags = 0; + let ptr = buf.as_ptr() as *const c_void; + + debug!("sending {:?}", buf.len()); + let res = unsafe { + send(self.as_raw_fd(), ptr, buf.len() as u64, flags) + }; + if res == (buf.len() as i64) { + Ok(res as usize) + } else { + warn!("send() found {}, expected {}", res, buf.len()); + Err(Error::new(ErrorKind::Other, "send() failed", Some(os::error_string(os::errno() as i32)))) + } + } + + fn flush(&mut self) -> Result<(),Error> { + Ok(()) + } +} + +pub trait SetTimeout { + fn set_timeout(&self, timeout: Duration); +} + +impl SetTimeout for S { + fn set_timeout(&self, timeout: Duration) { + let tv = timeval { + tv_sec: timeout.num_seconds(), + tv_usec: 0, + }; + + unsafe { + setsockopt(self.as_raw_fd(), SOL_SOCKET, SO_RCVTIMEO, + mem::transmute(&tv), mem::size_of_val(&tv) as u32) + }; + } +} + +#[test] +fn connect4_works() { + let socket1 = UdpSocket::bind("127.0.0.1:34200").unwrap(); + let socket2 = UdpSocket::bind("127.0.0.1:34201").unwrap(); + let conn1 = socket1.connect("127.0.0.1:34200").unwrap(); + let conn2 = socket2.connect("127.0.0.1:34201").unwrap(); +} + +#[test] +fn sendrecv_works() { + let socket1 = UdpSocket::bind("127.0.0.1:34200").unwrap(); + let socket2 = UdpSocket::bind("127.0.0.1:34201").unwrap(); + let mut conn1 = socket1.connect("127.0.0.1:34201").unwrap(); + let mut conn2 = socket2.connect("127.0.0.1:34200").unwrap(); + + let send1 = [0,1,2,3]; + let send2 = [9,8,7,6]; + conn1.write(&send1).unwrap(); + conn2.write(&send2).unwrap(); + + let mut recv1 = [0;4]; + let mut recv2 = [0;4]; + conn1.read(&mut recv1).unwrap(); + conn2.read(&mut recv2).unwrap(); + + assert_eq!(send1, recv2); + assert_eq!(send2, recv1); +} + +#[test] +fn sendrecv_respects_packet_borders() { + let socket1 = UdpSocket::bind("127.0.0.1:34202").unwrap(); + let socket2 = UdpSocket::bind("127.0.0.1:34203").unwrap(); + let mut conn1 = socket1.connect("127.0.0.1:34203").unwrap(); + let mut conn2 = socket2.connect("127.0.0.1:34202").unwrap(); + + let send1 = [0,1,2,3]; + let send2 = [9,8,7,6]; + conn1.write(&send1).unwrap(); + conn1.write(&send2).unwrap(); + + let mut recv1 = [0;3]; + let mut recv2 = [0;3]; + conn2.read(&mut recv1).unwrap(); + conn2.read(&mut recv2).unwrap(); + + assert!(send1[0..3] == recv1[0..3]); + assert!(send2[0..3] == recv2[0..3]); +} + +#[test] +fn connect6_works() { + let socket1 = UdpSocket::bind("::1:34200").unwrap(); + let socket2 = UdpSocket::bind("::1:34201").unwrap(); + let conn1 = socket1.connect("::1:34200").unwrap(); + let conn2 = socket2.connect("::1:34201").unwrap(); +} + +#[test] +#[should_fail] +fn detect_invalid_ipv4() { + let s = UdpSocket::bind("127.0.0.1:34300").unwrap(); + s.connect("254.254.254.254:34200").unwrap(); +} + +#[test] +#[should_fail] +fn detect_invalid_ipv6() { + let s = UdpSocket::bind("::1:34300").unwrap(); + s.connect("1200::AB00:1234::2552:7777:1313:34300").unwrap(); +} + +#[test] +#[should_fail] +fn double_bind() { + let socket1 = UdpSocket::bind("127.0.0.1:34301").unwrap(); + let socket2 = UdpSocket::bind("127.0.0.1:34301").unwrap(); + drop(socket1); + drop(socket2); +} diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 4c0b13f1..710a287d 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -25,6 +25,7 @@ use x509::{X509StoreContext, X509FileType, X509}; use crypto::pkey::PKey; pub mod error; +pub mod connected_socket; #[cfg(test)] mod tests; @@ -97,6 +98,9 @@ pub enum SslMethod { #[cfg(feature = "tlsv1_2")] /// Support TLSv1.2 protocol, requires the `tlsv1_2` feature. Tlsv1_2, + #[cfg(feature = "dtlsv1")] + /// Support DTLSv1 protocol, requires the `dtlsv1` feature. + Dtlsv1, } impl SslMethod { @@ -110,7 +114,9 @@ impl SslMethod { #[cfg(feature = "tlsv1_1")] SslMethod::Tlsv1_1 => ffi::TLSv1_1_method(), #[cfg(feature = "tlsv1_2")] - SslMethod::Tlsv1_2 => ffi::TLSv1_2_method() + SslMethod::Tlsv1_2 => ffi::TLSv1_2_method(), + #[cfg(feature = "dtlsv1")] + SslMethod::Dtlsv1 => ffi::TLSv1_method(), } } } diff --git a/openssl/src/ssl/tests.rs b/openssl/src/ssl/tests.rs index 05c9fe79..1da42082 100644 --- a/openssl/src/ssl/tests.rs +++ b/openssl/src/ssl/tests.rs @@ -11,6 +11,7 @@ use std::fs::File; use crypto::hash::Type::{SHA256}; use ssl; +use ssl::SslMethod; use ssl::SslMethod::Sslv23; use ssl::{SslContext, SslStream, VerifyCallback}; use ssl::SSL_VERIFY_PEER; @@ -20,21 +21,23 @@ use x509::X509FileType; use x509::X509; use crypto::pkey::PKey; +const PROTOCOL:SslMethod = Sslv23; + #[test] fn test_new_ctx() { - SslContext::new(Sslv23).unwrap(); + SslContext::new(PROTOCOL).unwrap(); } #[test] fn test_new_sslstream() { let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - SslStream::new(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); + SslStream::new(&SslContext::new(PROTOCOL).unwrap(), stream).unwrap(); } #[test] fn test_verify_untrusted() { let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut ctx = SslContext::new(Sslv23).unwrap(); + let mut ctx = SslContext::new(PROTOCOL).unwrap(); ctx.set_verify(SSL_VERIFY_PEER, None); match SslStream::new(&ctx, stream) { Ok(_) => panic!("expected failure"), @@ -45,8 +48,9 @@ fn test_verify_untrusted() { #[test] fn test_verify_trusted() { let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut ctx = SslContext::new(Sslv23).unwrap(); + let mut ctx = SslContext::new(PROTOCOL).unwrap(); ctx.set_verify(SSL_VERIFY_PEER, None); + match ctx.set_CA_file(&Path::new("test/cert.pem")) { Ok(_) => {} Err(err) => panic!("Unexpected error {:?}", err) @@ -63,8 +67,9 @@ fn test_verify_untrusted_callback_override_ok() { true } let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut ctx = SslContext::new(Sslv23).unwrap(); + let mut ctx = SslContext::new(PROTOCOL).unwrap(); ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback)); + match SslStream::new(&ctx, stream) { Ok(_) => (), Err(err) => panic!("Expected success, got {:?}", err) @@ -77,8 +82,9 @@ fn test_verify_untrusted_callback_override_bad() { false } let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut ctx = SslContext::new(Sslv23).unwrap(); + let mut ctx = SslContext::new(PROTOCOL).unwrap(); ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback)); + assert!(SslStream::new(&ctx, stream).is_err()); } @@ -88,8 +94,9 @@ fn test_verify_trusted_callback_override_ok() { true } let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut ctx = SslContext::new(Sslv23).unwrap(); + let mut ctx = SslContext::new(PROTOCOL).unwrap(); ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback)); + match ctx.set_CA_file(&Path::new("test/cert.pem")) { Ok(_) => {} Err(err) => panic!("Unexpected error {:?}", err) @@ -106,8 +113,9 @@ fn test_verify_trusted_callback_override_bad() { false } let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut ctx = SslContext::new(Sslv23).unwrap(); + let mut ctx = SslContext::new(PROTOCOL).unwrap(); ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback)); + match ctx.set_CA_file(&Path::new("test/cert.pem")) { Ok(_) => {} Err(err) => panic!("Unexpected error {:?}", err) @@ -122,8 +130,9 @@ fn test_verify_callback_load_certs() { true } let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut ctx = SslContext::new(Sslv23).unwrap(); + let mut ctx = SslContext::new(PROTOCOL).unwrap(); ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback)); + assert!(SslStream::new(&ctx, stream).is_ok()); } @@ -134,8 +143,9 @@ fn test_verify_trusted_get_error_ok() { true } let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut ctx = SslContext::new(Sslv23).unwrap(); + let mut ctx = SslContext::new(PROTOCOL).unwrap(); ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback)); + match ctx.set_CA_file(&Path::new("test/cert.pem")) { Ok(_) => {} Err(err) => panic!("Unexpected error {:?}", err) @@ -150,8 +160,9 @@ fn test_verify_trusted_get_error_err() { false } let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut ctx = SslContext::new(Sslv23).unwrap(); + let mut ctx = SslContext::new(PROTOCOL).unwrap(); ctx.set_verify(SSL_VERIFY_PEER, Some(callback as VerifyCallback)); + assert!(SslStream::new(&ctx, stream).is_err()); } @@ -168,7 +179,7 @@ fn test_verify_callback_data() { } } let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut ctx = SslContext::new(Sslv23).unwrap(); + let mut ctx = SslContext::new(PROTOCOL).unwrap(); // Node id was generated as SHA256 hash of certificate "test/cert.pem" // in DER format. @@ -234,7 +245,7 @@ fn test_clear_ctx_options() { #[test] fn test_write() { let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut stream = SslStream::new(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); + let mut stream = SslStream::new(&SslContext::new(PROTOCOL).unwrap(), stream).unwrap(); stream.write_all("hello".as_bytes()).unwrap(); stream.flush().unwrap(); stream.write_all(" there".as_bytes()).unwrap(); @@ -244,7 +255,7 @@ fn test_write() { #[test] fn test_read() { let stream = TcpStream::connect("127.0.0.1:15418").unwrap(); - let mut stream = SslStream::new(&SslContext::new(Sslv23).unwrap(), stream).unwrap(); + let mut stream = SslStream::new(&SslContext::new(PROTOCOL).unwrap(), stream).unwrap(); stream.write_all("GET /\r\n\r\n".as_bytes()).unwrap(); stream.flush().unwrap(); println!("written"); @@ -261,7 +272,7 @@ fn test_connect_with_unilateral_npn() { ctx.set_verify(SSL_VERIFY_PEER, None); ctx.set_npn_protocols(&[b"http/1.1", b"spdy/3.1"]); match ctx.set_CA_file(&Path::new("test/cert.pem")) { - Ok(_)=> {} + Ok(_) => {} Err(err) => panic!("Unexpected error {:?}", err) } let stream = match SslStream::new(&ctx, stream) { @@ -285,7 +296,7 @@ fn test_connect_with_npn_successful_multiple_matching() { ctx.set_verify(SSL_VERIFY_PEER, None); ctx.set_npn_protocols(&[b"spdy/3.1", b"http/1.1"]); match ctx.set_CA_file(&Path::new("test/cert.pem")) { - Ok(_)=> {} + Ok(_) => {} Err(err) => panic!("Unexpected error {:?}", err) } let stream = match SslStream::new(&ctx, stream) { @@ -310,7 +321,7 @@ fn test_connect_with_npn_successful_single_match() { ctx.set_verify(SSL_VERIFY_PEER, None); ctx.set_npn_protocols(&[b"spdy/3.1"]); match ctx.set_CA_file(&Path::new("test/cert.pem")) { - Ok(_)=> {} + Ok(_) => {} Err(err) => panic!("Unexpected error {:?}", err) } let stream = match SslStream::new(&ctx, stream) { @@ -350,7 +361,7 @@ fn test_npn_server_advertise_multiple() { ctx.set_verify(SSL_VERIFY_PEER, None); ctx.set_npn_protocols(&[b"spdy/3.1"]); match ctx.set_CA_file(&Path::new("test/cert.pem")) { - Ok(_)=> {} + Ok(_) => {} Err(err) => panic!("Unexpected error {:?}", err) } // Now connect to the socket and make sure the protocol negotiation works... @@ -362,3 +373,26 @@ fn test_npn_server_advertise_multiple() { // SPDY is selected since that's the only thing the client supports. assert_eq!(b"spdy/3.1", stream.get_selected_npn_protocol().unwrap()); } + +#[cfg(feature="dtlsv1")] +#[cfg(test)] +mod dtlsv1 { + use serialize::hex::FromHex; + use std::old_io::net::tcp::TcpStream; + use std::old_io::{Writer}; + use std::thread; + + use crypto::hash::Type::{SHA256}; + use ssl::SslMethod; + use ssl::SslMethod::Dtlsv1; + use ssl::{SslContext, SslStream, VerifyCallback}; + use ssl::SslVerifyMode::SSL_VERIFY_PEER; + use x509::{X509StoreContext}; + + const PROTOCOL:SslMethod = Dtlsv1; + + #[test] + fn test_new_ctx() { + SslContext::new(PROTOCOL).unwrap(); + } +}