Move connected_socket to its own crate and fix SSL_CTX_set_read_ahead()

This commit is contained in:
Manuel Schölling 2015-03-15 15:52:09 +01:00
parent 5788f3bec8
commit dbef985e32
6 changed files with 24 additions and 357 deletions

View File

@ -283,6 +283,9 @@ pub unsafe fn SSL_CTX_add_extra_chain_cert(ssl: *mut SSL_CTX, cert: *mut X509) -
SSL_CTX_ctrl(ssl, SSL_CTRL_EXTRA_CHAIN_CERT, 0, cert)
}
pub unsafe fn SSL_CTX_set_read_ahead(ctx: *mut SSL_CTX, m: c_long) -> c_long {
SSL_CTX_ctrl(ctx, SSL_CTRL_SET_READ_AHEAD, m, ptr::null_mut())
}
// True functions
extern "C" {

View File

@ -12,8 +12,8 @@ keywords = ["crypto", "tls", "ssl", "dtls"]
[features]
tlsv1_2 = ["openssl-sys/tlsv1_2"]
tlsv1_1 = ["openssl-sys/tlsv1_1"]
dtlsv1 = ["openssl-sys/dtlsv1"]
dtlsv1_2 = ["openssl-sys/dtlsv1_2"]
dtlsv1 = ["openssl-sys/dtlsv1", "connected_socket"]
dtlsv1_2 = ["openssl-sys/dtlsv1_2", "connected_socket"]
sslv2 = ["openssl-sys/sslv2"]
aes_xts = ["openssl-sys/aes_xts"]
npn = ["openssl-sys/npn"]
@ -29,3 +29,10 @@ libc = "0.1"
[dev-dependencies]
rustc-serialize = "0.3"
[dependencies]
bitflags = "0.1.1"
[dependencies-dev.connected_socket]
connected_socket = "0.0.1"
optional = true

View File

@ -10,6 +10,9 @@ extern crate openssl_sys as ffi;
#[cfg(test)]
extern crate rustc_serialize as serialize;
#[cfg(any(feature="dtlsv1", feature="dtlsv1_2"))]
extern crate connected_socket;
mod macros;
pub mod asn1;

View File

@ -1,346 +0,0 @@
use libc::funcs::bsd43::connect;
use std::os;
use std::os::unix::AsRawFd;
use std::os::unix::Fd;
use std::net::UdpSocket;
use std::net::ToSocketAddrs;
use std::net::SocketAddr;
use std::io::Error;
use std::io::ErrorKind;
use std::io::Read;
use std::io::Write;
use std::mem;
use std::time::duration::Duration;
use libc::types::os::common::bsd44::socklen_t;
use libc::types::os::common::bsd44::sockaddr_in;
use libc::types::os::common::bsd44::sockaddr_in6;
use libc::types::os::common::bsd44::in_addr;
use libc::types::os::common::bsd44::in6_addr;
use libc::types::os::common::bsd44::sa_family_t;
use libc::types::os::common::posix01::timeval;
use libc::funcs::bsd43::setsockopt;
use libc::consts::os::bsd44::SOL_SOCKET;
use libc::consts::os::bsd44::AF_INET;
use libc::consts::os::bsd44::AF_INET6;
use libc::consts::os::posix88::EAGAIN;
use std::net::IpAddr;
use libc::types::os::arch::c95::c_int;
use libc::types::os::arch::c95::c_char;
use libc::types::common::c95::c_void;
use libc::funcs::bsd43::send;
use libc::funcs::bsd43::recv;
use std::num::Int;
use std::os::errno;
use std::ffi::CString;
const SO_RCVTIMEO:c_int = 20;
extern {
fn inet_pton(family: c_int, src: *const c_char, dst: *mut c_void) -> c_int;
}
pub struct ConnectedSocket<S: ?Sized> {
sock: S
}
impl<S: AsRawFd+?Sized> AsRawFd for ConnectedSocket<S> {
fn as_raw_fd(&self) -> Fd {
self.sock.as_raw_fd()
}
}
enum SockaddrIn {
V4(sockaddr_in),
V6(sockaddr_in6),
}
#[cfg(any(target_os = "linux", target_os = "android", target_os = "nacl",
target_os = "windows"))]
fn new_sockaddr_in() -> sockaddr_in {
sockaddr_in {
sin_family: AF_INET as sa_family_t,
sin_port: 9,
sin_zero: [0; 8],
sin_addr: in_addr {
s_addr: 0
}
}
}
#[cfg(not(any(target_os = "linux", target_os = "android", target_os = "nacl",
target_os = "windows")))]
fn new_sockaddr_in() -> sockaddr_in {
sockaddr_in {
sin_len: mem::size_of::<sockaddr_in>() as u8,
sin_family: AF_INET as sa_family_t,
sin_port: 0,
sin_zero: [0; 8],
sin_addr: in_addr {
s_addr: 0
}
}
}
#[cfg(any(target_os = "linux", target_os = "android", target_os = "nacl",
target_os = "windows"))]
fn new_sockaddr_in6() -> sockaddr_in6 {
sockaddr_in6 {
sin6_family: AF_INET6 as sa_family_t,
sin6_port: 0,
sin6_flowinfo: 0,
sin6_scope_id: 0,
sin6_addr: in6_addr {
s6_addr: [0; 8],
}
}
}
#[cfg(not(any(target_os = "linux", target_os = "android", target_os = "nacl",
target_os = "windows")))]
fn new_sockaddr_in6() -> sockaddr_in6 {
sockaddr_in6 {
sin6_len: mem::size_of::<sockaddr_in6>() as u8,
sin6_family: AF_INET6 as sa_family_t,
sin6_port: 0,
sin6_flowinfo: 0,
sin6_scope_id: 0,
sin6_addr: in6_addr {
s6_addr: [0; 8],
}
}
}
trait IntoSockaddrIn {
fn into_sockaddr_in(self) -> Result<SockaddrIn, Error>;
}
impl IntoSockaddrIn for SocketAddr {
fn into_sockaddr_in(self) -> Result<SockaddrIn, Error> {
let ip = format!("{}", self.ip());
match self.ip() {
IpAddr::V4(_) => {
let mut addr = new_sockaddr_in();
addr.sin_port = Int::to_be(self.port());
let cstr = CString::new(ip.clone()).unwrap();
let res = unsafe {
inet_pton(addr.sin_family as c_int,
cstr.as_ptr() as *const i8,
mem::transmute(&mut addr.sin_addr))
};
if res == 1 {
Ok(SockaddrIn::V4(addr))
} else {
Err(Error::new(ErrorKind::Other,
"calling inet_pton() for ipv4", None))
}
},
IpAddr::V6(_) => {
let mut addr = new_sockaddr_in6();
addr.sin6_port = Int::to_be(self.port());
let cstr = CString::new(ip.clone()).unwrap();
let res = unsafe {
inet_pton(addr.sin6_family as c_int,
cstr.as_ptr() as *const i8,
mem::transmute(&mut addr.sin6_addr))
};
if res > 0 {
Ok(SockaddrIn::V6(addr))
} else {
Err(Error::new(ErrorKind::Other,
"calling inet_pton() for ipv6", None))
}
}
}
}
}
pub trait Connect {
fn connect<A: ToSocketAddrs + ?Sized>(self, addr: &A) -> Result<ConnectedSocket<Self>,Error>;
}
impl Connect for UdpSocket {
fn connect<A: ToSocketAddrs + ?Sized>(self, address: &A) -> Result<ConnectedSocket<Self>,Error> {
let fd = self.as_raw_fd();
let addr = try!(address.to_socket_addrs()).next();
if addr.is_none() {
return Err(Error::new(ErrorKind::InvalidInput,
"no addresses to connect to", None));
}
let saddr = try!(addr.unwrap().into_sockaddr_in());
let res = match saddr {
SockaddrIn::V4(s) => unsafe {
let len = mem::size_of_val(&s) as socklen_t;
let addrp = Box::new(s);
connect(fd, mem::transmute(&*addrp), len)
},
SockaddrIn::V6(s) => unsafe {
let len = mem::size_of_val(&s) as socklen_t;
let addrp = Box::new(s);
connect(fd, mem::transmute(&*addrp), len)
},
};
if res == 0 {
Ok(ConnectedSocket { sock: self })
} else {
Err(Error::new(ErrorKind::Other,
"error calling connect()", None))
}
}
}
impl<S: AsRawFd+?Sized> Read for ConnectedSocket<S> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize,Error> {
let flags = 0;
let ptr = buf.as_mut_ptr() as *mut c_void;
let len = unsafe {
recv(self.as_raw_fd(), ptr, buf.len() as u64, flags)
};
match len {
-1 => {
match errno() {
EAGAIN => Err(Error::new(ErrorKind::Interrupted, "EAGAIN", None)),
_ => Err(Error::new(ErrorKind::Other,
"recv() returned -1",
Some(os::error_string(os::errno() as i32)))),
}
},
0 => Err(Error::new(ErrorKind::Other,
"connection is closed", None)),
_ => Ok(len as usize),
}
}
}
impl<S: AsRawFd+?Sized> Write for ConnectedSocket<S> {
fn write(&mut self, buf: &[u8]) -> Result<usize,Error> {
let flags = 0;
let ptr = buf.as_ptr() as *const c_void;
let res = unsafe {
send(self.as_raw_fd(), ptr, buf.len() as u64, flags)
};
if res == (buf.len() as i64) {
Ok(res as usize)
} else {
Err(Error::new(ErrorKind::Other,
"send() failed",
Some(os::error_string(os::errno() as i32))))
}
}
fn flush(&mut self) -> Result<(),Error> {
Ok(())
}
}
pub trait SetTimeout {
fn set_timeout(&self, timeout: Duration);
}
impl<S:AsRawFd> SetTimeout for S {
fn set_timeout(&self, timeout: Duration) {
let tv = timeval {
tv_sec: timeout.num_seconds(),
tv_usec: 0,
};
unsafe {
setsockopt(self.as_raw_fd(), SOL_SOCKET, SO_RCVTIMEO,
mem::transmute(&tv), mem::size_of_val(&tv) as socklen_t)
};
}
}
#[test]
fn connect4_works() {
let socket1 = UdpSocket::bind("127.0.0.1:34200").unwrap();
let socket2 = UdpSocket::bind("127.0.0.1:34201").unwrap();
socket1.connect("127.0.0.1:34200").unwrap();
socket2.connect("127.0.0.1:34201").unwrap();
}
#[test]
fn sendrecv_works() {
let socket1 = UdpSocket::bind("127.0.0.1:34200").unwrap();
let socket2 = UdpSocket::bind("127.0.0.1:34201").unwrap();
let mut conn1 = socket1.connect("127.0.0.1:34201").unwrap();
let mut conn2 = socket2.connect("127.0.0.1:34200").unwrap();
let send1 = [0,1,2,3];
let send2 = [9,8,7,6];
conn1.write(&send1).unwrap();
conn2.write(&send2).unwrap();
let mut recv1 = [0;4];
let mut recv2 = [0;4];
conn1.read(&mut recv1).unwrap();
conn2.read(&mut recv2).unwrap();
assert_eq!(send1, recv2);
assert_eq!(send2, recv1);
}
#[test]
fn sendrecv_respects_packet_borders() {
let socket1 = UdpSocket::bind("127.0.0.1:34202").unwrap();
let socket2 = UdpSocket::bind("127.0.0.1:34203").unwrap();
let mut conn1 = socket1.connect("127.0.0.1:34203").unwrap();
let mut conn2 = socket2.connect("127.0.0.1:34202").unwrap();
let send1 = [0,1,2,3];
let send2 = [9,8,7,6];
conn1.write(&send1).unwrap();
conn1.write(&send2).unwrap();
let mut recv1 = [0;3];
let mut recv2 = [0;3];
conn2.read(&mut recv1).unwrap();
conn2.read(&mut recv2).unwrap();
assert!(send1[0..3] == recv1[0..3]);
assert!(send2[0..3] == recv2[0..3]);
}
#[test]
fn connect6_works() {
let socket1 = UdpSocket::bind("::1:34200").unwrap();
let socket2 = UdpSocket::bind("::1:34201").unwrap();
socket1.connect("::1:34200").unwrap();
socket2.connect("::1:34201").unwrap();
}
#[test]
#[should_panic]
fn detect_invalid_ipv4() {
let s = UdpSocket::bind("127.0.0.1:34300").unwrap();
s.connect("255.255.255.255:34200").unwrap();
}
#[test]
#[should_panic]
fn detect_invalid_ipv6() {
let s = UdpSocket::bind("::1:34300").unwrap();
s.connect("1200::AB00:1234::2552:7777:1313:34300").unwrap();
}
#[test]
#[should_panic]
fn double_bind() {
let socket1 = UdpSocket::bind("127.0.0.1:34301").unwrap();
let socket2 = UdpSocket::bind("127.0.0.1:34301").unwrap();
drop(socket1);
drop(socket2);
}

View File

@ -25,7 +25,6 @@ use x509::{X509StoreContext, X509FileType, X509};
use crypto::pkey::PKey;
pub mod error;
pub mod connected_socket;
#[cfg(test)]
mod tests;
@ -377,7 +376,7 @@ impl SslContext {
let ctx = SslContext { ctx: ctx };
if method.is_dtls() {
ctx.set_read_ahead();
ctx.set_read_ahead(1);
}
Ok(ctx)
@ -424,9 +423,9 @@ impl SslContext {
}
}
pub fn set_read_ahead(&self) {
pub fn set_read_ahead(&self, m: c_long) {
unsafe {
ffi::SSL_CTX_ctrl(*self.ctx, ffi::SSL_CTRL_SET_READ_AHEAD, 1, ptr::null_mut());
ffi::SSL_CTX_set_read_ahead(*self.ctx, m);
}
}

View File

@ -14,8 +14,6 @@ use crypto::hash::Type::{SHA256};
use ssl;
use ssl::SslMethod;
use ssl::SslMethod::Sslv23;
#[cfg(feature="dtlsv1")]
use ssl::SslMethod::Dtlsv1;
use ssl::{SslContext, SslStream, VerifyCallback};
use ssl::SSL_VERIFY_PEER;
use x509::X509StoreContext;
@ -29,7 +27,10 @@ use ssl::connected_socket::Connect;
#[cfg(feature="dtlsv1")]
use std::net::UdpSocket;
const PROTOCOL:SslMethod = Sslv23;
#[cfg(feature="dtlsv1")]
use ssl::SslMethod::Dtlsv1;
#[cfg(feature="dtlsv1")]
use connected_socket::Connect;
#[cfg(test)]
mod udp {
@ -61,7 +62,8 @@ macro_rules! run_test(
use ssl;
use ssl::SslMethod;
use ssl::{SslContext, SslStream, VerifyCallback};
use ssl::connected_socket::Connect;
#[cfg(feature="dtlsv1")]
use connected_socket::Connect;
use ssl::SslVerifyMode::SSL_VERIFY_PEER;
use crypto::hash::Type::SHA256;
use x509::X509StoreContext;
@ -319,7 +321,6 @@ fn test_read() {
io::copy(&mut stream, &mut io::sink()).ok().expect("read error");
}
/// Tests that connecting with the client using NPN, but the server not does not
/// break the existing connection behavior.
#[test]