diff --git a/src/socket.rs b/src/socket.rs index 22fe858b..7a6832a3 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -285,7 +285,7 @@ impl Socket { )))] { let (socket, addr) = self.accept_raw()?; - let socket = set_common_flags(socket)?; + let socket = set_common_accept_flags(socket)?; // `set_common_flags` does not disable inheritance on Windows because `Socket::new` // unlike `accept` is able to create the socket with inheritance disabled. #[cfg(windows)] @@ -762,8 +762,8 @@ const fn set_common_type(ty: Type) -> Type { } /// Set `FD_CLOEXEC` and `NOSIGPIPE` on the `socket` for platforms that need it. -#[inline(always)] -#[allow(clippy::unnecessary_wraps)] +/// +/// Sockets created via `accept` should use `set_common_accept_flags` instead. fn set_common_flags(socket: Socket) -> io::Result { // On platforms that don't have `SOCK_CLOEXEC` use `FD_CLOEXEC`. #[cfg(all( @@ -798,6 +798,46 @@ fn set_common_flags(socket: Socket) -> io::Result { Ok(socket) } +/// Set `FD_CLOEXEC` on the `socket` for platforms that need it. +/// +/// Unlike `set_common_flags` we don't set `NOSIGPIPE` as that is inherited from +/// the listener. Furthermore, attempts to set it on a unix socket domain +/// results in an error. +#[cfg(not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + target_os = "cygwin", +)))] +fn set_common_accept_flags(socket: Socket) -> io::Result { + // On platforms that don't have `SOCK_CLOEXEC` use `FD_CLOEXEC`. + #[cfg(all( + unix, + not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "hurd", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + target_os = "espidf", + target_os = "vita", + target_os = "cygwin", + )) + ))] + socket._set_cloexec(true)?; + + Ok(socket) +} + /// A local interface specified by its index or an address assigned to it. /// /// `Index(0)` and `Address(Ipv4Addr::UNSPECIFIED)` are equivalent and indicate diff --git a/tests/socket.rs b/tests/socket.rs index d34083cc..86dc8059 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -511,7 +511,7 @@ where panic!("unexpected error: {}", io::Error::last_os_error()); } assert_eq!(length as usize, size_of::()); - assert_eq!(flags, want as _, "non-blocking option"); + assert_eq!(flags, want as _); } const DATA: &[u8] = b"hello world"; @@ -605,7 +605,7 @@ fn unix() { return; } let mut path = env::temp_dir(); - path.push("socket2"); + path.push("socket2.unix"); let _ = fs::remove_dir_all(&path); fs::create_dir_all(&path).unwrap(); path.push("unix"); @@ -631,6 +631,30 @@ fn unix() { assert_eq!(&buf[..n], DATA); } +#[test] +fn unix_accept() { + if !unix_sockets_supported() { + return; + } + let mut path = env::temp_dir(); + path.push("socket2.unix_accept"); + let _ = fs::remove_dir_all(&path); + fs::create_dir_all(&path).unwrap(); + path.push("unix_accept"); + + let listener = Socket::new(Domain::UNIX, Type::STREAM, None).unwrap(); + listener.bind(&SockAddr::unix(&path).unwrap()).unwrap(); + listener.listen(1).unwrap(); + + Socket::new(Domain::UNIX, Type::STREAM, None) + .unwrap() + .connect(&SockAddr::unix(path).unwrap()) + .unwrap(); + + let (socket, _) = listener.accept().unwrap(); + assert_common_flags(&socket, true); +} + #[test] #[cfg(all(feature = "all", any(target_os = "android", target_os = "linux")))] #[ignore = "using VSOCK family requires optional kernel support (works when enabled)"]