From b90e468da75be163d9f66167c46f3a8eb1004160 Mon Sep 17 00:00:00 2001 From: qianz Date: Wed, 19 Feb 2025 10:20:00 +0800 Subject: [PATCH 1/8] refactor:add socket parameters --- network/netstack/main.go | 2 +- network/netstack/socket/net.go | 5 +++-- network/netstack/socket/socket_it_test.go | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/network/netstack/main.go b/network/netstack/main.go index 5ae3e5e..9ebb9a4 100755 --- a/network/netstack/main.go +++ b/network/netstack/main.go @@ -27,7 +27,7 @@ func main() { panic(err) } socket.SetupDefaultNetwork(context.Background(), tun, socket.NetworkOptions{Debug: true}) - serverFd, err := socket.Socket() + serverFd, err := socket.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP) if err != nil { panic(err) } diff --git a/network/netstack/socket/net.go b/network/netstack/socket/net.go index a2ea647..08570df 100755 --- a/network/netstack/socket/net.go +++ b/network/netstack/socket/net.go @@ -20,7 +20,7 @@ type SockFile interface { Write(b []byte) (n int, err error) } -func Socket() (fd int, err error) { +func Socket(domain int, typ int, protocol int) (fd int, err error) { if defaultNetwork == nil { return 0, ErrNoNetwork } @@ -47,7 +47,8 @@ func Accept(fd int) (cfd int, err error) { if defaultNetwork == nil { return 0, ErrNoNetwork } - return defaultNetwork.accept(fd) + cfd, err = defaultNetwork.accept(fd) + return cfd, err } func AcceptWithTimeout(fd int, timeout time.Duration) (cfd int, err error) { diff --git a/network/netstack/socket/socket_it_test.go b/network/netstack/socket/socket_it_test.go index 526686c..1d56701 100644 --- a/network/netstack/socket/socket_it_test.go +++ b/network/netstack/socket/socket_it_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "golang.org/x/sys/unix" ) func EnsureNetwork() (err error) { @@ -31,7 +32,7 @@ func NewServer(hostAddr string, handler tcpHandler) (server *Server, err error) if err := EnsureNetwork(); err != nil { return nil, err } - serverFd, err := Socket() + serverFd, err := Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP) if err != nil { return nil, err } @@ -155,7 +156,7 @@ func TestActiveConnection(t *testing.T) { l.Close() }() - cfd, err := Socket() + cfd, err := Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP) assert.NoError(t, err) assert.NoError(t, Connect(cfd, serverAddr)) From 5a43474917a7e61bfd604255a844fa1d3a076bb8 Mon Sep 17 00:00:00 2001 From: qianz Date: Wed, 19 Feb 2025 10:51:03 +0800 Subject: [PATCH 2/8] refactor:change SocketAddr fields --- network/netstack/socket/net.go | 57 ++++------ network/netstack/socket/packet_builder.go | 10 +- network/netstack/socket/socket.go | 124 ++++++---------------- network/netstack/socket/socket_test.go | 10 +- 4 files changed, 64 insertions(+), 137 deletions(-) diff --git a/network/netstack/socket/net.go b/network/netstack/socket/net.go index 08570df..1f2c85b 100755 --- a/network/netstack/socket/net.go +++ b/network/netstack/socket/net.go @@ -240,10 +240,10 @@ func (n *Network) handle(data []byte) { tcpPack := ipPack.Payload.(*tcpip.TcpPack) sock, ok := n.getSocket( SocketAddr{ - SrcIP: ipPack.SrcIP.String(), - SrcPort: tcpPack.SrcPort, - DstIP: ipPack.DstIP.String(), - DstPort: tcpPack.DstPort, + RemoteIP: ipPack.SrcIP.String(), + RemotePort: tcpPack.SrcPort, + LocalIP: ipPack.DstIP.String(), + LocalPort: tcpPack.DstPort, }, ) if !ok { @@ -251,7 +251,7 @@ func (n *Network) handle(data []byte) { } sock.Lock() if sock.State == tcpip.TcpStateUnInitialized { - log.Printf("socket %s:%d is not initialized,drop packet", sock.localIP, sock.localPort) + log.Printf("socket %s:%d is not initialized,drop packet", sock.LocalIP, sock.LocalPort) sock.Unlock() return } @@ -259,7 +259,7 @@ func (n *Network) handle(data []byte) { select { case sock.writeCh <- ipPack: default: - log.Printf("socket %s:%d is full,drop packet", sock.localIP, sock.localPort) + log.Printf("socket %s:%d is full,drop packet", sock.LocalIP, sock.LocalPort) } } @@ -274,11 +274,11 @@ func (n *Network) bind(fd int, addr string) (err error) { if !ok { return fmt.Errorf("%w: %d", ErrNoSocket, fd) } - sock.localIP = ip.String() - sock.localPort = port + sock.LocalIP = ip.String() + sock.LocalPort = port n.bindSocket(SocketAddr{ - DstIP: ip.String(), - DstPort: port, + LocalIP: ip.String(), + LocalPort: port, }, fd) return nil } @@ -345,34 +345,25 @@ func (n *Network) connect(fd int, serverAddr string) (err error) { return fmt.Errorf("%w: %d", ErrNoSocket, fd) } var addr SocketAddr - if sock.localIP == "" && sock.localPort == 0 { + if sock.LocalIP == "" && sock.LocalPort == 0 { addr, err = n.getAvailableAddress() if err != nil { return err } - sock.localIP = addr.DstIP - sock.localPort = addr.DstPort } else { n.unbindSocket(SocketAddr{ - DstIP: sock.localIP, - DstPort: sock.localPort, + LocalIP: sock.LocalIP, + LocalPort: sock.LocalPort, }) addr = SocketAddr{ - DstIP: sock.localIP, - DstPort: sock.localPort, + LocalIP: sock.LocalIP, + LocalPort: sock.LocalPort, } } - addr.SrcIP = serverIP.String() - addr.SrcPort = serverPort + addr.RemoteIP = serverIP.String() + addr.RemotePort = serverPort n.bindSocket(addr, fd) - InitConnectSocket( - sock, - nil, - net.ParseIP(sock.localIP), - sock.localPort, - serverIP, - serverPort, - ) + InitConnectSocket(sock, nil, addr) return sock.Connect() } @@ -382,8 +373,8 @@ func (n *Network) getSocket(addr SocketAddr) (sock *TcpSocket, ok bool) { return n.getSocketByFd(value.(int)) } newAddr := SocketAddr{ - DstIP: addr.DstIP, - DstPort: addr.DstPort, + LocalIP: addr.LocalIP, + LocalPort: addr.LocalPort, } value, ok = n.socketFds.Load(newAddr) if ok { @@ -417,11 +408,9 @@ func (n *Network) getAvailableAddress() (addr SocketAddr, err error) { localIp := ip.String() var p uint16 for p = n.opt.IpLocalPortRange.Start; p <= n.opt.IpLocalPortRange.End; p++ { - if _, ok := n.socketFds.Load(SocketAddr{DstIP: localIp, DstPort: p}); !ok { - return SocketAddr{ - DstIP: localIp, - DstPort: p, - }, nil + addr := SocketAddr{LocalIP: localIp, LocalPort: p} + if _, ok := n.socketFds.Load(addr); !ok { + return addr, nil } } return SocketAddr{}, fmt.Errorf("no available address") diff --git a/network/netstack/socket/packet_builder.go b/network/netstack/socket/packet_builder.go index 73e8f85..8baf3ee 100755 --- a/network/netstack/socket/packet_builder.go +++ b/network/netstack/socket/packet_builder.go @@ -34,10 +34,10 @@ func NewPacketBuilder(opt NetworkOptions) *PacketBuilder { } func (b *PacketBuilder) SetAddr(addr SocketAddr) *PacketBuilder { - srcIP := net.ParseIP(addr.SrcIP).To4() - dstIP := net.ParseIP(addr.DstIP).To4() + srcIP := net.ParseIP(addr.LocalIP).To4() + dstIP := net.ParseIP(addr.RemoteIP).To4() if srcIP == nil || dstIP == nil { - b.err = fmt.Errorf("invalid IPv4 address: %s, %s", addr.SrcIP, addr.DstIP) + b.err = fmt.Errorf("invalid IPv4 address: %s, %s", addr.LocalIP, addr.RemoteIP) return b } b.ip.IPHeader.SrcIP = srcIP @@ -46,8 +46,8 @@ func (b *PacketBuilder) SetAddr(addr SocketAddr) *PacketBuilder { SrcIP: b.ip.IPHeader.SrcIP, DstIP: b.ip.IPHeader.DstIP, } - b.tcp.TcpHeader.SrcPort = addr.SrcPort - b.tcp.TcpHeader.DstPort = addr.DstPort + b.tcp.TcpHeader.SrcPort = addr.LocalPort + b.tcp.TcpHeader.DstPort = addr.RemotePort return b } diff --git a/network/netstack/socket/socket.go b/network/netstack/socket/socket.go index 8056488..2e528c8 100755 --- a/network/netstack/socket/socket.go +++ b/network/netstack/socket/socket.go @@ -5,27 +5,27 @@ import ( "io" "log" "math/rand" - "net" "netstack/tcpip" "sync" "time" ) type SocketAddr struct { - SrcIP string - SrcPort uint16 - DstIP string - DstPort uint16 + LocalIP string + RemoteIP string + LocalPort uint16 + RemotePort uint16 } type TcpSocket struct { sync.Mutex + SocketAddr fd int - localIP string - remoteIP string - localPort uint16 - remotePort uint16 + // localIP string + // remoteIP string + // localPort uint16 + // remotePort uint16 network *Network acceptQueue chan *TcpSocket @@ -92,18 +92,12 @@ func InitListenSocket(sock *TcpSocket) { func InitConnectSocket( sock *TcpSocket, listenSocket *TcpSocket, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, + addr SocketAddr, ) { sock.Lock() defer sock.Unlock() sock.listener = listenSocket - sock.localIP = localIP.String() - sock.remoteIP = remoteIP.String() - sock.localPort = localPort - sock.remotePort = remotePort + sock.SocketAddr = addr sock.readCh = make(chan []byte, 1024) sock.writeCh = make(chan *tcpip.IPPack) sock.sendBuffer = make([]byte, 1024) @@ -258,10 +252,12 @@ func (s *TcpSocket) handleNewSocket(ipPack *tcpip.IPPack, tcpPack *tcpip.TcpPack InitConnectSocket( sock, s, - ipPack.DstIP, - tcpPack.DstPort, - ipPack.SrcIP, - tcpPack.SrcPort, + SocketAddr{ + LocalIP: ipPack.DstIP.String(), + LocalPort: tcpPack.DstPort, + RemoteIP: ipPack.SrcIP.String(), + RemotePort: tcpPack.SrcPort, + }, ) } sock.handle(ipPack, tcpPack) @@ -280,12 +276,7 @@ func (s *TcpSocket) handleSyn(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, err e } ipResp, tcpResp, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(seq). SetAck(tcpPack.SequenceNumber + 1). SetFlags(tcpip.TcpSYN | tcpip.TcpACK). @@ -320,12 +311,7 @@ func (s *TcpSocket) handleSynResp(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, e } s.State = tcpip.TcpStateEstablished ipResp, _, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext - 1). SetAck(tcpPack.SequenceNumber + 1). SetFlags(tcpip.TcpACK). @@ -348,7 +334,7 @@ func (s *TcpSocket) handleSynResp(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, e func (s *TcpSocket) handleFirstAck(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, err error) { s.State = tcpip.TcpStateEstablished s.sendUnack = tcpPack.AckNumber - s.synQueue.Delete(s.remotePort) + s.synQueue.Delete(s.RemotePort) select { case s.listener.acceptQueue <- s: default: @@ -356,12 +342,7 @@ func (s *TcpSocket) handleFirstAck(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, } s.network.addSocket(s) - s.network.bindSocket(SocketAddr{ - SrcIP: s.remoteIP, - SrcPort: s.remotePort, - DstIP: s.localIP, - DstPort: s.localPort, - }, s.fd) + s.network.bindSocket(s.SocketAddr, s.fd) go s.runloop() return nil, nil } @@ -389,12 +370,7 @@ func (s *TcpSocket) handleData(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, err } ipResp, _, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpACK). @@ -410,12 +386,7 @@ func (s *TcpSocket) handleFin() (resp *tcpip.IPPack, err error) { s.recvNext += 1 s.State = tcpip.TcpStateCloseWait ipResp, _, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpACK). @@ -432,12 +403,7 @@ func (s *TcpSocket) handleFin() (resp *tcpip.IPPack, err error) { func (s *TcpSocket) handleLastAck() { s.State = tcpip.TcpStateClosed s.network.removeSocket(s.fd) - s.network.unbindSocket(SocketAddr{ - SrcIP: s.remoteIP, - SrcPort: s.remotePort, - DstIP: s.localIP, - DstPort: s.localPort, - }) + s.network.unbindSocket(s.SocketAddr) } func (s *TcpSocket) handleFinWait1( @@ -475,12 +441,7 @@ func (s *TcpSocket) handleFinWait2Fin(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPac } ipResp, _, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpACK). @@ -491,12 +452,7 @@ func (s *TcpSocket) handleFinWait2Fin(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPac s.State = tcpip.TcpStateClosed s.network.removeSocket(s.fd) - s.network.unbindSocket(SocketAddr{ - SrcIP: s.remoteIP, - SrcPort: s.remotePort, - DstIP: s.localIP, - DstPort: s.localPort, - }) + s.network.unbindSocket(s.SocketAddr) return ipResp, nil } @@ -550,12 +506,7 @@ func (s *TcpSocket) passiveCloseSocket() (ipResp *tcpip.IPPack, err error) { s.State = tcpip.TcpStateLastAck ipResp, tcpResp, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpFIN | tcpip.TcpACK). @@ -574,12 +525,7 @@ func (s *TcpSocket) activeCloseSocket() (ipResp *tcpip.IPPack, err error) { s.State = tcpip.TcpStateFinWait1 ipResp, tcpResp, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpFIN | tcpip.TcpACK). @@ -610,12 +556,7 @@ func (s *TcpSocket) handleSend(data []byte) (send int, resp *tcpip.IPPack, err e } ipResp, _, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpACK). @@ -698,12 +639,7 @@ func (s *TcpSocket) activeConnect() (ipResp *tcpip.IPPack, err error) { seq = s.network.opt.Seq } ipResp, tcpResp, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(seq). SetFlags(tcpip.TcpSYN). Build() diff --git a/network/netstack/socket/socket_test.go b/network/netstack/socket/socket_test.go index 4969744..c62f357 100755 --- a/network/netstack/socket/socket_test.go +++ b/network/netstack/socket/socket_test.go @@ -44,10 +44,12 @@ func TestSocketServer(t *testing.T) { InitConnectSocket( connectSock, listenSock, - net.ParseIP(args.dstIp), - args.dstPort, - net.ParseIP(args.srcIp), - args.srcPort, + SocketAddr{ + RemoteIP: args.srcIp, + RemotePort: args.srcPort, + LocalIP: args.dstIp, + LocalPort: args.dstPort, + }, ) client := endpoint{ip: args.srcIp, port: args.srcPort, t: t} server := endpoint{ip: args.dstIp, port: args.dstPort, t: t} From bda182aad77b7dc9890e73725b6d6957c436fad7 Mon Sep 17 00:00:00 2001 From: qianz Date: Wed, 19 Feb 2025 11:00:50 +0800 Subject: [PATCH 3/8] refactor:accept return SocketAddr --- network/netstack/main.go | 2 +- network/netstack/socket/net.go | 15 ++++++++------- network/netstack/socket/socket.go | 4 ++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/network/netstack/main.go b/network/netstack/main.go index 9ebb9a4..0e6867f 100755 --- a/network/netstack/main.go +++ b/network/netstack/main.go @@ -43,7 +43,7 @@ func main() { fmt.Printf("server start success %d, listen on: %s\n", serverFd, hostAddr) for { fmt.Println("accepting...") - connFd, err := socket.Accept(serverFd) + connFd, _, err := socket.Accept(serverFd) if err != nil { log.Println(err) continue diff --git a/network/netstack/socket/net.go b/network/netstack/socket/net.go index 1f2c85b..0c0b757 100755 --- a/network/netstack/socket/net.go +++ b/network/netstack/socket/net.go @@ -43,12 +43,12 @@ func Listen(fd int, backlog uint) (err error) { return defaultNetwork.listen(fd, backlog) } -func Accept(fd int) (cfd int, err error) { +func Accept(fd int) (cfd int, addr SocketAddr, err error) { if defaultNetwork == nil { - return 0, ErrNoNetwork + return 0, SocketAddr{}, ErrNoNetwork } - cfd, err = defaultNetwork.accept(fd) - return cfd, err + cfd, addr, err = defaultNetwork.accept(fd) + return cfd, addr, err } func AcceptWithTimeout(fd int, timeout time.Duration) (cfd int, err error) { @@ -292,12 +292,13 @@ func (n *Network) listen(fd int, backlog uint) (err error) { return sock.Listen(backlog) } -func (n *Network) accept(fd int) (cfd int, err error) { +func (n *Network) accept(fd int) (cfd int, addr SocketAddr, err error) { sock, ok := n.getSocketByFd(fd) if !ok { - return 0, fmt.Errorf("%w: %d", ErrNoSocket, fd) + return 0, SocketAddr{}, fmt.Errorf("%w: %d", ErrNoSocket, fd) } - return sock.Accept() + cfd, addr, err = sock.Accept() + return cfd, addr, err } func (n *Network) acceptWithTimeout(fd int, timeout time.Duration) (cfd int, err error) { diff --git a/network/netstack/socket/socket.go b/network/netstack/socket/socket.go index 2e528c8..2ee7b41 100755 --- a/network/netstack/socket/socket.go +++ b/network/netstack/socket/socket.go @@ -110,11 +110,11 @@ func (s *TcpSocket) Listen(backlog uint) (err error) { return nil } -func (s *TcpSocket) Accept() (cfd int, err error) { +func (s *TcpSocket) Accept() (cfd int, addr SocketAddr, err error) { cs := <-s.acceptQueue cs.Lock() defer cs.Unlock() - return cs.fd, nil + return cs.fd, cs.SocketAddr, nil } func (s *TcpSocket) AcceptWithTimeout(timeout time.Duration) (cfd int, err error) { From 89fd0aa21a6789cb9d85bf85bc46ddc076c05bb8 Mon Sep 17 00:00:00 2001 From: qianz Date: Wed, 19 Feb 2025 14:11:53 +0800 Subject: [PATCH 4/8] refactor:set seq with sendNext and ack with recvNext --- network/netstack/socket/socket.go | 32 +++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/network/netstack/socket/socket.go b/network/netstack/socket/socket.go index 2ee7b41..127cbd6 100755 --- a/network/netstack/socket/socket.go +++ b/network/netstack/socket/socket.go @@ -275,18 +275,20 @@ func (s *TcpSocket) handleSyn(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, err e seq = s.network.opt.Seq } - ipResp, tcpResp, err := NewPacketBuilder(s.network.opt). + s.sendUnack = seq + s.sendNext = seq + + ipResp, _, err := NewPacketBuilder(s.network.opt). SetAddr(s.SocketAddr). - SetSeq(seq). - SetAck(tcpPack.SequenceNumber + 1). + SetSeq(s.sendNext). + SetAck(s.recvNext). SetFlags(tcpip.TcpSYN | tcpip.TcpACK). Build() if err != nil { return nil, err } - s.sendUnack = tcpResp.SequenceNumber - s.sendNext = tcpResp.SequenceNumber + 1 + s.sendNext++ return ipResp, nil } @@ -309,18 +311,20 @@ func (s *TcpSocket) handleSynResp(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, e tcpPack.AckNumber, ) } + s.State = tcpip.TcpStateEstablished + s.recvNext = tcpPack.SequenceNumber + 1 + ipResp, _, err := NewPacketBuilder(s.network.opt). SetAddr(s.SocketAddr). - SetSeq(s.sendNext - 1). - SetAck(tcpPack.SequenceNumber + 1). + SetSeq(s.sendNext). + SetAck(s.recvNext). SetFlags(tcpip.TcpACK). Build() if err != nil { return nil, err } s.sendUnack++ - s.recvNext = tcpPack.SequenceNumber + 1 select { case s.listener.acceptQueue <- s: @@ -638,17 +642,21 @@ func (s *TcpSocket) activeConnect() (ipResp *tcpip.IPPack, err error) { } else { seq = s.network.opt.Seq } - ipResp, tcpResp, err := NewPacketBuilder(s.network.opt). + + s.sendUnack = seq + s.sendNext = seq + + ipResp, _, err = NewPacketBuilder(s.network.opt). SetAddr(s.SocketAddr). - SetSeq(seq). + SetSeq(s.sendNext). SetFlags(tcpip.TcpSYN). Build() if err != nil { return nil, err } - s.sendUnack = tcpResp.SequenceNumber - s.sendNext = tcpResp.SequenceNumber + 1 + s.sendNext++ + s.listener = s return ipResp, nil From 253c47fe4ca605942c5980f9f11427facf5501ec Mon Sep 17 00:00:00 2001 From: qianz Date: Wed, 19 Feb 2025 15:36:36 +0800 Subject: [PATCH 5/8] fix:ack range --- network/netstack/socket/socket.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/network/netstack/socket/socket.go b/network/netstack/socket/socket.go index 127cbd6..3aa6353 100755 --- a/network/netstack/socket/socket.go +++ b/network/netstack/socket/socket.go @@ -589,7 +589,7 @@ func (s *TcpSocket) checkSeqAck(tcpPack *tcpip.TcpPack) (valid bool) { if s.sendUnack == s.sendNext { return tcpPack.AckNumber == s.sendNext } - return tcpPack.AckNumber >= s.sendUnack && tcpPack.AckNumber <= s.sendNext + return tcpPack.AckNumber > s.sendUnack && tcpPack.AckNumber <= s.sendNext } func (s *TcpSocket) cacheSendData(data []byte) int { From 6365a837db2e19f3027e776692817eab51a01ad6 Mon Sep 17 00:00:00 2001 From: qianz Date: Wed, 19 Feb 2025 16:31:09 +0800 Subject: [PATCH 6/8] refactor:clean codes --- network/netstack/socket/socket.go | 49 +++++-------------------------- 1 file changed, 7 insertions(+), 42 deletions(-) diff --git a/network/netstack/socket/socket.go b/network/netstack/socket/socket.go index 3aa6353..2927e11 100755 --- a/network/netstack/socket/socket.go +++ b/network/netstack/socket/socket.go @@ -20,26 +20,23 @@ type SocketAddr struct { type TcpSocket struct { sync.Mutex SocketAddr + State tcpip.TcpState + fd int - // localIP string - // remoteIP string - // localPort uint16 - // remotePort uint16 + network *Network + listener *TcpSocket - network *Network acceptQueue chan *TcpSocket synQueue sync.Map - readCh chan []byte - writeCh chan *tcpip.IPPack - listener *TcpSocket + readCh chan []byte + writeCh chan *tcpip.IPPack + recvNext uint32 sendNext uint32 sendUnack uint32 sendBuffer []byte - - State tcpip.TcpState } func NewSocket(network *Network) *TcpSocket { @@ -48,38 +45,6 @@ func NewSocket(network *Network) *TcpSocket { } } -// func NewListenSocket(network *Network) *Socket { -// return &Socket{ -// network: network, -// synQueue: sync.Map{}, -// acceptQueue: make(chan *Socket, network.opt.Backlog), -// readCh: make(chan []byte), -// writeCh: make(chan *tcpip.IPPack), -// State: tcpip.TcpStateListen, -// } -// } - -// func NewConnectSocket( -// listenSocket *Socket, -// localIP net.IP, -// localPort uint16, -// remoteIP net.IP, -// remotePort uint16, -// ) *Socket { -// return &Socket{ -// network: listenSocket.network, -// listener: listenSocket, -// localIP: localIP.String(), -// remoteIP: remoteIP.String(), -// localPort: localPort, -// remotePort: remotePort, -// State: tcpip.TcpStateClosed, -// readCh: make(chan []byte, 1024), -// writeCh: make(chan *tcpip.IPPack), -// sendBuffer: make([]byte, 1024), -// } -// } - func InitListenSocket(sock *TcpSocket) { sock.Lock() defer sock.Unlock() From 0bf3c23fb26a1d1173863a6f1b272e4fe9805011 Mon Sep 17 00:00:00 2001 From: qianz Date: Thu, 20 Feb 2025 15:06:16 +0800 Subject: [PATCH 7/8] refactor:delete useless log --- network/netstack/socket/socket.go | 1 - 1 file changed, 1 deletion(-) diff --git a/network/netstack/socket/socket.go b/network/netstack/socket/socket.go index 2927e11..4248d38 100755 --- a/network/netstack/socket/socket.go +++ b/network/netstack/socket/socket.go @@ -518,7 +518,6 @@ func (s *TcpSocket) handleSend(data []byte) (send int, resp *tcpip.IPPack, err e return 0, nil, nil } - log.Println("handle send data", len(data), string(data)) send = s.cacheSendData(data) if send == 0 { return 0, nil, nil From 93f64fcea7de99859111cad68693402dad97f173 Mon Sep 17 00:00:00 2001 From: qianz Date: Thu, 20 Feb 2025 16:16:50 +0800 Subject: [PATCH 8/8] fix:send last ack --- network/netstack/socket/socket.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/network/netstack/socket/socket.go b/network/netstack/socket/socket.go index 4248d38..ff47cbe 100755 --- a/network/netstack/socket/socket.go +++ b/network/netstack/socket/socket.go @@ -397,16 +397,16 @@ func (s *TcpSocket) handleFinWait2Fin(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPac if err != nil { return nil, fmt.Errorf("encode tcp payload failed %w", err) } - if len(data) == 0 { - return nil, nil - } + // +1 for FIN s.recvNext = s.recvNext + uint32(len(data)) + 1 - select { - case s.readCh <- data: - default: - return nil, fmt.Errorf("the reader queue is full, drop the data") + if len(data) > 0 { + select { + case s.readCh <- data: + default: + return nil, fmt.Errorf("the reader queue is full, drop the data") + } } ipResp, _, err := NewPacketBuilder(s.network.opt). @@ -422,6 +422,7 @@ func (s *TcpSocket) handleFinWait2Fin(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPac s.State = tcpip.TcpStateClosed s.network.removeSocket(s.fd) s.network.unbindSocket(s.SocketAddr) + close(s.readCh) return ipResp, nil }