Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions awkernel_lib/src/net/if_net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use smoltcp::{
wire::HardwareAddress,
};

use crate::sync::{mcs::MCSNode, mutex::Mutex};
use crate::sync::{mcs::MCSNode, mutex::Mutex, rwlock::RwLock};

use super::{
ether::{extract_headers, NetworkHdr, TransportHdr, ETHER_ADDR_LEN},
Expand Down Expand Up @@ -163,6 +163,7 @@ impl Device for NetDriverRef<'_> {
pub(super) struct IfNet {
vlan: Option<u16>,
pub(super) inner: Mutex<IfNetInner>,
pub(super) socket_set: RwLock<SocketSet<'static>>,
rx_irq_to_driver: BTreeMap<u16, NetDriver>,
tx_only_ringq: Vec<Mutex<RingQ<Vec<u8>>>>,
pub(super) net_device: Arc<dyn NetDevice + Sync + Send>,
Expand All @@ -173,7 +174,6 @@ pub(super) struct IfNet {

pub(super) struct IfNetInner {
pub(super) interface: Interface,
pub(super) socket_set: SocketSet<'static>,
pub(super) default_gateway_ipv4: Option<smoltcp::wire::Ipv4Address>,

multicast_addr_ipv4: BTreeSet<Ipv4Addr>,
Expand All @@ -182,8 +182,8 @@ pub(super) struct IfNetInner {

impl IfNetInner {
#[inline(always)]
pub fn split(&mut self) -> (&mut Interface, &mut SocketSet<'static>) {
(&mut self.interface, &mut self.socket_set)
pub fn get_interface(&mut self) -> &mut Interface {
&mut self.interface
}

#[inline(always)]
Expand Down Expand Up @@ -277,11 +277,11 @@ impl IfNet {
vlan,
inner: Mutex::new(IfNetInner {
interface,
socket_set,
default_gateway_ipv4: None,
multicast_addr_ipv4: BTreeSet::new(),
multicast_addr_mac: BTreeMap::new(),
}),
socket_set: RwLock::new(socket_set),
rx_irq_to_driver,
net_device,
tx_only_ringq,
Expand Down Expand Up @@ -488,8 +488,8 @@ impl IfNet {
let mut node = MCSNode::new();
let mut inner = self.inner.lock(&mut node);

let (interface, socket_set) = inner.split();
interface.poll(timestamp, &mut device_ref, socket_set)
let interface = inner.get_interface();
interface.poll(timestamp, &mut device_ref, &self.socket_set)
};

// send packets from the queue.
Expand Down Expand Up @@ -547,9 +547,8 @@ impl IfNet {
let mut node = MCSNode::new();
let mut inner = self.inner.lock(&mut node);

let (interface, socket_set) = inner.split();

interface.poll(timestamp, &mut device_ref, socket_set)
let interface = inner.get_interface();
interface.poll(timestamp, &mut device_ref, &self.socket_set)
};

// send packets from the queue.
Expand Down
79 changes: 46 additions & 33 deletions awkernel_lib/src/net/tcp_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,7 @@ impl TcpListener {
// Create a TCP socket.
let socket = create_listen_socket(&addr, port.port(), rx_buffer_size, tx_buffer_size);

let handle = {
let mut node = MCSNode::new();
let mut if_net_inner = if_net.inner.lock(&mut node);

if_net_inner.socket_set.add(socket)
};
let handle = if_net.socket_set.write().add(socket);

handles.push(handle);
}
Expand Down Expand Up @@ -128,43 +123,54 @@ impl TcpListener {
let if_net = if_net.clone();
drop(net_manager);

let mut node = MCSNode::new();
let mut interface = if_net.inner.lock(&mut node);

for handle in self.handles.iter_mut() {
let socket: &mut smoltcp::socket::tcp::Socket = interface.socket_set.get_mut(*handle);
if socket.may_send() {
let (may_send, is_not_open) = {
let socket_set = if_net.socket_set.read();
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(*handle)
.lock(&mut node);
if socket.may_send() {
(true, false)
} else if !socket.is_open() {
(false, true)
} else {
(false, false)
}
};

if may_send {
// If the connection is established, create a new socket and add it to the interface.
let new_socket = create_listen_socket(
&self.addr,
self.port.port(),
self.rx_buffer_size,
self.tx_buffer_size,
);
let mut new_handle = interface.socket_set.add(new_socket);
let mut socket_set = if_net.socket_set.write();
let mut new_handle = socket_set.add(new_socket);

// Swap the new handle with the old handle.
core::mem::swap(handle, &mut new_handle);

// The old handle is now a connected socket.
self.connected_sockets.push_back(new_handle);
} else if !socket.is_open() {
} else if is_not_open {
// If the socket is closed, create a new socket and add it to the interface.
let new_socket = create_listen_socket(
&self.addr,
self.port.port(),
self.rx_buffer_size,
self.tx_buffer_size,
);
interface.socket_set.remove(*handle);
*handle = interface.socket_set.add(new_socket);
let mut socket_set = if_net.socket_set.write();
socket_set.remove(*handle);
*handle = socket_set.add(new_socket);
}
}

// If there is a connected socket, return it.
if let Some(handle) = self.connected_sockets.pop_front() {
drop(interface);

let port = {
let mut net_manager = NET_MANAGER.write();
if self.addr.is_ipv4() {
Expand All @@ -183,14 +189,16 @@ impl TcpListener {
}));
}

let socket_set = if_net.socket_set.read();
// Register the waker for the listening sockets.
for handle in self.handles.iter() {
let socket: &mut smoltcp::socket::tcp::Socket = interface.socket_set.get_mut(*handle);
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let mut socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(*handle)
.lock(&mut node);
socket.register_send_waker(waker);
}

drop(interface);

Ok(None)
}
}
Expand All @@ -203,23 +211,28 @@ impl Drop for TcpListener {
let if_net = if_net.clone();
drop(net_manager);

let mut node = MCSNode::new();
let mut inner = if_net.inner.lock(&mut node);
{
let socket_set = if_net.socket_set.read();

// Close listening sockets.
for handle in self.handles.iter() {
let socket: &mut smoltcp::socket::tcp::Socket = inner.socket_set.get_mut(*handle);
socket.abort();
}
// Close listening sockets.
for handle in self.handles.iter() {
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let mut socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(*handle)
.lock(&mut node);
socket.abort();
}

// Close connected sockets.
for handle in self.connected_sockets.iter() {
let socket: &mut smoltcp::socket::tcp::Socket = inner.socket_set.get_mut(*handle);
socket.abort();
// Close connected sockets.
for handle in self.connected_sockets.iter() {
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let mut socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(*handle)
.lock(&mut node);
socket.abort();
}
}

drop(inner);

let que_id = crate::cpu::raw_cpu_id() & (if_net.net_device.num_queues() - 1);
if_net.poll_tx_only(que_id);
}
Expand Down
98 changes: 65 additions & 33 deletions awkernel_lib/src/net/tcp_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,29 @@ impl Drop for TcpStream {
drop(net_manager);

{
let mut node = MCSNode::new();
let mut inner = if_net.inner.lock(&mut node);
let socket_set = if_net.socket_set.read();
let closed = {
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(self.handle)
.lock(&mut node);

let socket: &mut smoltcp::socket::tcp::Socket = inner.socket_set.get_mut(self.handle);
matches!(socket.state(), smoltcp::socket::tcp::State::Closed)
};

// If the socket is already closed, remove it from the socket set.
if matches!(socket.state(), smoltcp::socket::tcp::State::Closed) {
inner.socket_set.remove(self.handle);

if closed {
drop(socket_set);
let mut socket_set = if_net.socket_set.write();
socket_set.remove(self.handle);
return;
}

// Otherwise, close the socket.
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let mut socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(self.handle)
.lock(&mut node);
socket.close();
}

Expand Down Expand Up @@ -98,16 +108,25 @@ pub fn close_connections() {
let mut remain_v = VecDeque::new();

{
let mut node = MCSNode::new();
let mut inner = if_net.inner.lock(&mut node);

while let Some((handle, port)) = v.pop_front() {
let socket: &mut smoltcp::socket::tcp::Socket =
inner.socket_set.get_mut(handle);
if socket.state() == smoltcp::socket::tcp::State::Closed {
let socket_set = if_net.socket_set.read();
let closed = {
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(handle)
.lock(&mut node);
socket.state() == smoltcp::socket::tcp::State::Closed
};
if closed {
drop(socket_set);
let mut socket_set = if_net.socket_set.write();
// If the socket is already closed, remove it from the socket set.
inner.socket_set.remove(handle);
socket_set.remove(handle);
} else {
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let mut socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(handle)
.lock(&mut node);
socket.close();
remain_v.push_back((handle, port));
}
Expand Down Expand Up @@ -180,20 +199,27 @@ impl TcpStream {
let mut node = MCSNode::new();
let mut inner = if_net.inner.lock(&mut node);

let (interface, socket_set) = inner.split();
let interface = inner.get_interface();

let mut socket_set = if_net.socket_set.write();
handle = socket_set.add(socket);

let socket: &mut smoltcp::socket::tcp::Socket = socket_set.get_mut(handle);

if socket
.connect(
interface.context(),
(remote_addr.addr, remote_port),
local_port.port(),
)
.is_err()
{
let connect_is_err = {
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let mut socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(handle)
.lock(&mut node);

socket
.connect(
interface.context(),
(remote_addr.addr, remote_port),
local_port.port(),
)
.is_err()
};

if connect_is_err {
socket_set.remove(handle);
return Err(NetManagerError::InvalidState);
}
Expand Down Expand Up @@ -227,10 +253,12 @@ impl TcpStream {
let if_net = if_net.clone();
drop(net_manager);

let mut node = MCSNode::new();
let mut inner = if_net.inner.lock(&mut node);
let socket_set = if_net.socket_set.read();

let socket: &mut smoltcp::socket::tcp::Socket = inner.socket_set.get_mut(self.handle);
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let mut socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(self.handle)
.lock(&mut node);

if socket.state() == smoltcp::socket::tcp::State::SynSent {
socket.register_recv_waker(waker);
Expand Down Expand Up @@ -271,10 +299,12 @@ impl TcpStream {
let if_net = if_net.clone();
drop(net_manager);

let mut node = MCSNode::new();
let mut inner = if_net.inner.lock(&mut node);
let socket_set = if_net.socket_set.read();

let socket: &mut smoltcp::socket::tcp::Socket = inner.socket_set.get_mut(self.handle);
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let mut socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(self.handle)
.lock(&mut node);

if socket.state() == smoltcp::socket::tcp::State::SynSent {
socket.register_recv_waker(waker);
Expand Down Expand Up @@ -308,10 +338,12 @@ impl TcpStream {
let if_net = if_net.clone();
drop(net_manager);

let mut node = MCSNode::new();
let inner = if_net.inner.lock(&mut node);
let socket_set = if_net.socket_set.read();

let socket: &smoltcp::socket::tcp::Socket = inner.socket_set.get(self.handle);
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
let socket = socket_set
.get::<smoltcp::socket::tcp::Socket>(self.handle)
.lock(&mut node);

if let Some(endpoint) = socket.remote_endpoint() {
Ok((
Expand Down
Loading
Loading