Merge pull request #294 from alexcrichton/nonblocking-tests
Get nonblocking tests working on OSX/Windows
This commit is contained in:
commit
0b76ee3bd9
|
|
@ -33,5 +33,4 @@ libc = "0.1"
|
|||
|
||||
[dev-dependencies]
|
||||
rustc-serialize = "0.3"
|
||||
net2 = "0.2.13"
|
||||
nix = "0.4"
|
||||
net2 = "0.2.16"
|
||||
|
|
|
|||
|
|
@ -4,17 +4,21 @@ use std::fs::File;
|
|||
use std::io::prelude::*;
|
||||
use std::io::{self, BufReader};
|
||||
use std::iter;
|
||||
use std::mem;
|
||||
use std::net::{TcpStream, TcpListener, SocketAddr};
|
||||
use std::path::Path;
|
||||
use std::process::{Command, Child, Stdio, ChildStdin};
|
||||
use std::thread;
|
||||
|
||||
use net2::TcpStreamExt;
|
||||
|
||||
use crypto::hash::Type::{SHA256};
|
||||
use ssl;
|
||||
use ssl::SslMethod;
|
||||
use ssl::SslMethod::Sslv23;
|
||||
use ssl::{SslContext, SslStream, VerifyCallback};
|
||||
use ssl::SSL_VERIFY_PEER;
|
||||
use ssl::SslMethod::Sslv23;
|
||||
use ssl::SslMethod;
|
||||
use ssl::error::NonblockingSslError;
|
||||
use ssl::{SslContext, SslStream, VerifyCallback, NonblockingSslStream};
|
||||
use x509::X509StoreContext;
|
||||
use x509::X509FileType;
|
||||
use x509::X509;
|
||||
|
|
@ -29,6 +33,8 @@ use ssl::SslMethod::Sslv2;
|
|||
#[cfg(feature="dtlsv1")]
|
||||
use net2::UdpSocketExt;
|
||||
|
||||
mod select;
|
||||
|
||||
fn next_addr() -> SocketAddr {
|
||||
use std::sync::atomic::{AtomicUsize, ATOMIC_USIZE_INIT, Ordering};
|
||||
static PORT: AtomicUsize = ATOMIC_USIZE_INIT;
|
||||
|
|
@ -331,7 +337,8 @@ run_test!(verify_trusted_get_error_err, |method, stream| {
|
|||
});
|
||||
|
||||
run_test!(verify_callback_data, |method, stream| {
|
||||
fn callback(_preverify_ok: bool, x509_ctx: &X509StoreContext, node_id: &Vec<u8>) -> bool {
|
||||
fn callback(_preverify_ok: bool, x509_ctx: &X509StoreContext,
|
||||
node_id: &Vec<u8>) -> bool {
|
||||
let cert = x509_ctx.get_current_cert();
|
||||
match cert {
|
||||
None => false,
|
||||
|
|
@ -808,7 +815,8 @@ mod dtlsv1 {
|
|||
fn test_read_dtlsv1() {
|
||||
let (_s, stream) = Server::new_dtlsv1(Some("hello"));
|
||||
|
||||
let mut stream = SslStream::connect_generic(&SslContext::new(Dtlsv1).unwrap(), stream).unwrap();
|
||||
let mut stream = SslStream::connect_generic(&SslContext::new(Dtlsv1).unwrap(),
|
||||
stream).unwrap();
|
||||
let mut buf = [0u8;100];
|
||||
assert!(stream.read(&mut buf).is_ok());
|
||||
}
|
||||
|
|
@ -817,67 +825,40 @@ fn test_read_dtlsv1() {
|
|||
#[cfg(feature = "sslv2")]
|
||||
fn test_sslv2_connect_failure() {
|
||||
let (_s, tcp) = Server::new_tcp(&["-no_ssl2", "-www"]);
|
||||
SslStream::connect_generic(&SslContext::new(Sslv2).unwrap(), tcp).err().unwrap();
|
||||
SslStream::connect_generic(&SslContext::new(Sslv2).unwrap(),
|
||||
tcp).err().unwrap();
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
mod nonblocking_tests {
|
||||
extern crate nix;
|
||||
fn wait_io(stream: &NonblockingSslStream<TcpStream>,
|
||||
read: bool,
|
||||
timeout_ms: u32) -> bool {
|
||||
unsafe {
|
||||
let mut set: select::fd_set = mem::zeroed();
|
||||
select::fd_set(&mut set, stream.get_ref());
|
||||
|
||||
use std::io::Write;
|
||||
use std::net::TcpStream;
|
||||
use std::os::unix::io::AsRawFd;
|
||||
|
||||
use super::Server;
|
||||
use self::nix::sys::epoll;
|
||||
use self::nix::fcntl;
|
||||
use ssl;
|
||||
use ssl::error::NonblockingSslError;
|
||||
use ssl::SslMethod;
|
||||
use ssl::SslMethod::Sslv23;
|
||||
use ssl::{SslContext, NonblockingSslStream};
|
||||
|
||||
fn wait_io(stream: &NonblockingSslStream<TcpStream>, read: bool, timeout_ms: isize) -> bool {
|
||||
let fd = stream.as_raw_fd();
|
||||
let ep = epoll::epoll_create().unwrap();
|
||||
let event = if read {
|
||||
epoll::EpollEvent {
|
||||
events: epoll::EPOLLIN | epoll::EPOLLERR,
|
||||
data: 0,
|
||||
}
|
||||
} else {
|
||||
epoll::EpollEvent {
|
||||
events: epoll::EPOLLOUT,
|
||||
data: 0,
|
||||
}
|
||||
};
|
||||
epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlAdd, fd, &event).unwrap();
|
||||
let mut events = [event];
|
||||
let count = epoll::epoll_wait(ep, &mut events, timeout_ms).unwrap();
|
||||
epoll::epoll_ctl(ep, epoll::EpollOp::EpollCtlDel, fd, &event).unwrap();
|
||||
assert!(count <= 1);
|
||||
count == 1
|
||||
let write = if read {0 as *mut _} else {&mut set as *mut _};
|
||||
let read = if !read {0 as *mut _} else {&mut set as *mut _};
|
||||
select::select(stream.get_ref(), read, write, 0 as *mut _, timeout_ms)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn make_nonblocking(stream: &TcpStream) {
|
||||
let fd = stream.as_raw_fd();
|
||||
fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(fcntl::O_NONBLOCK)).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_write_nonblocking() {
|
||||
#[test]
|
||||
fn test_write_nonblocking() {
|
||||
let (_s, stream) = Server::new();
|
||||
make_nonblocking(&stream);
|
||||
let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
|
||||
stream.set_nonblocking(true).unwrap();
|
||||
let cx = SslContext::new(Sslv23).unwrap();
|
||||
let mut stream = NonblockingSslStream::connect(&cx, stream).unwrap();
|
||||
|
||||
let mut iterations = 0;
|
||||
loop {
|
||||
iterations += 1;
|
||||
if iterations > 7 {
|
||||
// Probably a safe assumption for the foreseeable future of openssl.
|
||||
// Probably a safe assumption for the foreseeable future of
|
||||
// openssl.
|
||||
panic!("Too many read/write round trips in handshake!!");
|
||||
}
|
||||
let result = stream.write("hello".as_bytes());
|
||||
let result = stream.write(b"hello");
|
||||
match result {
|
||||
Ok(_) => {
|
||||
break;
|
||||
|
|
@ -894,25 +875,27 @@ mod nonblocking_tests {
|
|||
}
|
||||
}
|
||||
|
||||
// Second write should succeed immediately--plenty of space in kernel buffer,
|
||||
// and handshake just completed.
|
||||
// Second write should succeed immediately--plenty of space in kernel
|
||||
// buffer, and handshake just completed.
|
||||
stream.write(" there".as_bytes()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_nonblocking() {
|
||||
#[test]
|
||||
fn test_read_nonblocking() {
|
||||
let (_s, stream) = Server::new();
|
||||
make_nonblocking(&stream);
|
||||
let mut stream = NonblockingSslStream::connect(&SslContext::new(Sslv23).unwrap(), stream).unwrap();
|
||||
stream.set_nonblocking(true).unwrap();
|
||||
let cx = SslContext::new(Sslv23).unwrap();
|
||||
let mut stream = NonblockingSslStream::connect(&cx, stream).unwrap();
|
||||
|
||||
let mut iterations = 0;
|
||||
loop {
|
||||
iterations += 1;
|
||||
if iterations > 7 {
|
||||
// Probably a safe assumption for the foreseeable future of openssl.
|
||||
// Probably a safe assumption for the foreseeable future of
|
||||
// openssl.
|
||||
panic!("Too many read/write round trips in handshake!!");
|
||||
}
|
||||
let result = stream.write("GET /\r\n\r\n".as_bytes());
|
||||
let result = stream.write(b"GET /\r\n\r\n");
|
||||
match result {
|
||||
Ok(n) => {
|
||||
assert_eq!(n, 9);
|
||||
|
|
@ -949,5 +932,4 @@ mod nonblocking_tests {
|
|||
};
|
||||
assert!(bytes_read >= 5);
|
||||
assert_eq!(&input_buffer[..5], b"HTTP/");
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
use libc;
|
||||
pub use self::imp::*;
|
||||
|
||||
extern "system" {
|
||||
#[link_name = "select"]
|
||||
fn raw_select(nfds: libc::c_int,
|
||||
readfds: *mut fd_set,
|
||||
writefds: *mut fd_set,
|
||||
errorfds: *mut fd_set,
|
||||
timeout: *mut libc::timeval) -> libc::c_int;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
mod imp {
|
||||
use std::os::unix::prelude::*;
|
||||
use std::io;
|
||||
use libc;
|
||||
|
||||
const FD_SETSIZE: usize = 1024;
|
||||
|
||||
#[repr(C)]
|
||||
pub struct fd_set {
|
||||
fds_bits: [u64; FD_SETSIZE / 64]
|
||||
}
|
||||
|
||||
pub fn fd_set<F: AsRawFd>(set: &mut fd_set, f: &F) {
|
||||
let fd = f.as_raw_fd() as usize;
|
||||
set.fds_bits[fd / 64] |= 1 << (fd % 64);
|
||||
}
|
||||
|
||||
pub unsafe fn select<F: AsRawFd>(max: &F,
|
||||
read: *mut fd_set,
|
||||
write: *mut fd_set,
|
||||
error: *mut fd_set,
|
||||
timeout_ms: u32)
|
||||
-> io::Result<bool> {
|
||||
let mut timeout = libc::timeval {
|
||||
tv_sec: (timeout_ms / 1000) as libc::time_t,
|
||||
tv_usec: (timeout_ms % 1000 * 1000) as libc::suseconds_t,
|
||||
};
|
||||
let rc = super::raw_select(max.as_raw_fd() + 1, read, write, error,
|
||||
&mut timeout);
|
||||
if rc < 0 {
|
||||
Err(io::Error::last_os_error())
|
||||
} else {
|
||||
Ok(rc != 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
mod imp {
|
||||
use std::os::windows::prelude::*;
|
||||
use std::io;
|
||||
use libc::{SOCKET, c_uint, c_long, timeval};
|
||||
|
||||
const FD_SETSIZE: usize = 64;
|
||||
|
||||
#[repr(C)]
|
||||
pub struct fd_set {
|
||||
fd_count: c_uint,
|
||||
fd_array: [SOCKET; FD_SETSIZE],
|
||||
}
|
||||
|
||||
pub fn fd_set<F: AsRawSocket>(set: &mut fd_set, f: &F) {
|
||||
set.fd_array[set.fd_count as usize] = f.as_raw_socket();
|
||||
set.fd_count += 1;
|
||||
}
|
||||
|
||||
pub unsafe fn select<F: AsRawSocket>(_max: &F,
|
||||
read: *mut fd_set,
|
||||
write: *mut fd_set,
|
||||
error: *mut fd_set,
|
||||
timeout_ms: u32)
|
||||
-> io::Result<bool> {
|
||||
let mut timeout = timeval {
|
||||
tv_sec: (timeout_ms / 1000) as c_long,
|
||||
tv_usec: (timeout_ms % 1000 * 1000) as c_long,
|
||||
};
|
||||
let rc = super::raw_select(1, read, write, error, &mut timeout);
|
||||
if rc < 0 {
|
||||
Err(io::Error::last_os_error())
|
||||
} else {
|
||||
Ok(rc != 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue