diff --git a/network/netstack/main.go b/network/netstack/main.go index 5ae3e5e..0e6867f 100755 --- a/network/netstack/main.go +++ b/network/netstack/main.go @@ -27,7 +27,7 @@ func main() { panic(err) } socket.SetupDefaultNetwork(context.Background(), tun, socket.NetworkOptions{Debug: true}) - serverFd, err := socket.Socket() + serverFd, err := socket.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP) if err != nil { panic(err) } @@ -43,7 +43,7 @@ func main() { fmt.Printf("server start success %d, listen on: %s\n", serverFd, hostAddr) for { fmt.Println("accepting...") - connFd, err := socket.Accept(serverFd) + connFd, _, err := socket.Accept(serverFd) if err != nil { log.Println(err) continue diff --git a/network/netstack/socket/net.go b/network/netstack/socket/net.go index a2ea647..0c0b757 100755 --- a/network/netstack/socket/net.go +++ b/network/netstack/socket/net.go @@ -20,7 +20,7 @@ type SockFile interface { Write(b []byte) (n int, err error) } -func Socket() (fd int, err error) { +func Socket(domain int, typ int, protocol int) (fd int, err error) { if defaultNetwork == nil { return 0, ErrNoNetwork } @@ -43,11 +43,12 @@ func Listen(fd int, backlog uint) (err error) { return defaultNetwork.listen(fd, backlog) } -func Accept(fd int) (cfd int, err error) { +func Accept(fd int) (cfd int, addr SocketAddr, err error) { if defaultNetwork == nil { - return 0, ErrNoNetwork + return 0, SocketAddr{}, ErrNoNetwork } - return defaultNetwork.accept(fd) + cfd, addr, err = defaultNetwork.accept(fd) + return cfd, addr, err } func AcceptWithTimeout(fd int, timeout time.Duration) (cfd int, err error) { @@ -239,10 +240,10 @@ func (n *Network) handle(data []byte) { tcpPack := ipPack.Payload.(*tcpip.TcpPack) sock, ok := n.getSocket( SocketAddr{ - SrcIP: ipPack.SrcIP.String(), - SrcPort: tcpPack.SrcPort, - DstIP: ipPack.DstIP.String(), - DstPort: tcpPack.DstPort, + RemoteIP: ipPack.SrcIP.String(), + RemotePort: tcpPack.SrcPort, + LocalIP: ipPack.DstIP.String(), + LocalPort: tcpPack.DstPort, }, ) if !ok { @@ -250,7 +251,7 @@ func (n *Network) handle(data []byte) { } sock.Lock() if sock.State == tcpip.TcpStateUnInitialized { - log.Printf("socket %s:%d is not initialized,drop packet", sock.localIP, sock.localPort) + log.Printf("socket %s:%d is not initialized,drop packet", sock.LocalIP, sock.LocalPort) sock.Unlock() return } @@ -258,7 +259,7 @@ func (n *Network) handle(data []byte) { select { case sock.writeCh <- ipPack: default: - log.Printf("socket %s:%d is full,drop packet", sock.localIP, sock.localPort) + log.Printf("socket %s:%d is full,drop packet", sock.LocalIP, sock.LocalPort) } } @@ -273,11 +274,11 @@ func (n *Network) bind(fd int, addr string) (err error) { if !ok { return fmt.Errorf("%w: %d", ErrNoSocket, fd) } - sock.localIP = ip.String() - sock.localPort = port + sock.LocalIP = ip.String() + sock.LocalPort = port n.bindSocket(SocketAddr{ - DstIP: ip.String(), - DstPort: port, + LocalIP: ip.String(), + LocalPort: port, }, fd) return nil } @@ -291,12 +292,13 @@ func (n *Network) listen(fd int, backlog uint) (err error) { return sock.Listen(backlog) } -func (n *Network) accept(fd int) (cfd int, err error) { +func (n *Network) accept(fd int) (cfd int, addr SocketAddr, err error) { sock, ok := n.getSocketByFd(fd) if !ok { - return 0, fmt.Errorf("%w: %d", ErrNoSocket, fd) + return 0, SocketAddr{}, fmt.Errorf("%w: %d", ErrNoSocket, fd) } - return sock.Accept() + cfd, addr, err = sock.Accept() + return cfd, addr, err } func (n *Network) acceptWithTimeout(fd int, timeout time.Duration) (cfd int, err error) { @@ -344,34 +346,25 @@ func (n *Network) connect(fd int, serverAddr string) (err error) { return fmt.Errorf("%w: %d", ErrNoSocket, fd) } var addr SocketAddr - if sock.localIP == "" && sock.localPort == 0 { + if sock.LocalIP == "" && sock.LocalPort == 0 { addr, err = n.getAvailableAddress() if err != nil { return err } - sock.localIP = addr.DstIP - sock.localPort = addr.DstPort } else { n.unbindSocket(SocketAddr{ - DstIP: sock.localIP, - DstPort: sock.localPort, + LocalIP: sock.LocalIP, + LocalPort: sock.LocalPort, }) addr = SocketAddr{ - DstIP: sock.localIP, - DstPort: sock.localPort, + LocalIP: sock.LocalIP, + LocalPort: sock.LocalPort, } } - addr.SrcIP = serverIP.String() - addr.SrcPort = serverPort + addr.RemoteIP = serverIP.String() + addr.RemotePort = serverPort n.bindSocket(addr, fd) - InitConnectSocket( - sock, - nil, - net.ParseIP(sock.localIP), - sock.localPort, - serverIP, - serverPort, - ) + InitConnectSocket(sock, nil, addr) return sock.Connect() } @@ -381,8 +374,8 @@ func (n *Network) getSocket(addr SocketAddr) (sock *TcpSocket, ok bool) { return n.getSocketByFd(value.(int)) } newAddr := SocketAddr{ - DstIP: addr.DstIP, - DstPort: addr.DstPort, + LocalIP: addr.LocalIP, + LocalPort: addr.LocalPort, } value, ok = n.socketFds.Load(newAddr) if ok { @@ -416,11 +409,9 @@ func (n *Network) getAvailableAddress() (addr SocketAddr, err error) { localIp := ip.String() var p uint16 for p = n.opt.IpLocalPortRange.Start; p <= n.opt.IpLocalPortRange.End; p++ { - if _, ok := n.socketFds.Load(SocketAddr{DstIP: localIp, DstPort: p}); !ok { - return SocketAddr{ - DstIP: localIp, - DstPort: p, - }, nil + addr := SocketAddr{LocalIP: localIp, LocalPort: p} + if _, ok := n.socketFds.Load(addr); !ok { + return addr, nil } } return SocketAddr{}, fmt.Errorf("no available address") diff --git a/network/netstack/socket/packet_builder.go b/network/netstack/socket/packet_builder.go index 73e8f85..8baf3ee 100755 --- a/network/netstack/socket/packet_builder.go +++ b/network/netstack/socket/packet_builder.go @@ -34,10 +34,10 @@ func NewPacketBuilder(opt NetworkOptions) *PacketBuilder { } func (b *PacketBuilder) SetAddr(addr SocketAddr) *PacketBuilder { - srcIP := net.ParseIP(addr.SrcIP).To4() - dstIP := net.ParseIP(addr.DstIP).To4() + srcIP := net.ParseIP(addr.LocalIP).To4() + dstIP := net.ParseIP(addr.RemoteIP).To4() if srcIP == nil || dstIP == nil { - b.err = fmt.Errorf("invalid IPv4 address: %s, %s", addr.SrcIP, addr.DstIP) + b.err = fmt.Errorf("invalid IPv4 address: %s, %s", addr.LocalIP, addr.RemoteIP) return b } b.ip.IPHeader.SrcIP = srcIP @@ -46,8 +46,8 @@ func (b *PacketBuilder) SetAddr(addr SocketAddr) *PacketBuilder { SrcIP: b.ip.IPHeader.SrcIP, DstIP: b.ip.IPHeader.DstIP, } - b.tcp.TcpHeader.SrcPort = addr.SrcPort - b.tcp.TcpHeader.DstPort = addr.DstPort + b.tcp.TcpHeader.SrcPort = addr.LocalPort + b.tcp.TcpHeader.DstPort = addr.RemotePort return b } diff --git a/network/netstack/socket/socket.go b/network/netstack/socket/socket.go index 8056488..ff47cbe 100755 --- a/network/netstack/socket/socket.go +++ b/network/netstack/socket/socket.go @@ -5,41 +5,38 @@ import ( "io" "log" "math/rand" - "net" "netstack/tcpip" "sync" "time" ) type SocketAddr struct { - SrcIP string - SrcPort uint16 - DstIP string - DstPort uint16 + LocalIP string + RemoteIP string + LocalPort uint16 + RemotePort uint16 } type TcpSocket struct { sync.Mutex + SocketAddr + State tcpip.TcpState + fd int - localIP string - remoteIP string - localPort uint16 - remotePort uint16 + network *Network + listener *TcpSocket - network *Network acceptQueue chan *TcpSocket synQueue sync.Map - readCh chan []byte - writeCh chan *tcpip.IPPack - listener *TcpSocket + readCh chan []byte + writeCh chan *tcpip.IPPack + recvNext uint32 sendNext uint32 sendUnack uint32 sendBuffer []byte - - State tcpip.TcpState } func NewSocket(network *Network) *TcpSocket { @@ -48,38 +45,6 @@ func NewSocket(network *Network) *TcpSocket { } } -// func NewListenSocket(network *Network) *Socket { -// return &Socket{ -// network: network, -// synQueue: sync.Map{}, -// acceptQueue: make(chan *Socket, network.opt.Backlog), -// readCh: make(chan []byte), -// writeCh: make(chan *tcpip.IPPack), -// State: tcpip.TcpStateListen, -// } -// } - -// func NewConnectSocket( -// listenSocket *Socket, -// localIP net.IP, -// localPort uint16, -// remoteIP net.IP, -// remotePort uint16, -// ) *Socket { -// return &Socket{ -// network: listenSocket.network, -// listener: listenSocket, -// localIP: localIP.String(), -// remoteIP: remoteIP.String(), -// localPort: localPort, -// remotePort: remotePort, -// State: tcpip.TcpStateClosed, -// readCh: make(chan []byte, 1024), -// writeCh: make(chan *tcpip.IPPack), -// sendBuffer: make([]byte, 1024), -// } -// } - func InitListenSocket(sock *TcpSocket) { sock.Lock() defer sock.Unlock() @@ -92,18 +57,12 @@ func InitListenSocket(sock *TcpSocket) { func InitConnectSocket( sock *TcpSocket, listenSocket *TcpSocket, - localIP net.IP, - localPort uint16, - remoteIP net.IP, - remotePort uint16, + addr SocketAddr, ) { sock.Lock() defer sock.Unlock() sock.listener = listenSocket - sock.localIP = localIP.String() - sock.remoteIP = remoteIP.String() - sock.localPort = localPort - sock.remotePort = remotePort + sock.SocketAddr = addr sock.readCh = make(chan []byte, 1024) sock.writeCh = make(chan *tcpip.IPPack) sock.sendBuffer = make([]byte, 1024) @@ -116,11 +75,11 @@ func (s *TcpSocket) Listen(backlog uint) (err error) { return nil } -func (s *TcpSocket) Accept() (cfd int, err error) { +func (s *TcpSocket) Accept() (cfd int, addr SocketAddr, err error) { cs := <-s.acceptQueue cs.Lock() defer cs.Unlock() - return cs.fd, nil + return cs.fd, cs.SocketAddr, nil } func (s *TcpSocket) AcceptWithTimeout(timeout time.Duration) (cfd int, err error) { @@ -258,10 +217,12 @@ func (s *TcpSocket) handleNewSocket(ipPack *tcpip.IPPack, tcpPack *tcpip.TcpPack InitConnectSocket( sock, s, - ipPack.DstIP, - tcpPack.DstPort, - ipPack.SrcIP, - tcpPack.SrcPort, + SocketAddr{ + LocalIP: ipPack.DstIP.String(), + LocalPort: tcpPack.DstPort, + RemoteIP: ipPack.SrcIP.String(), + RemotePort: tcpPack.SrcPort, + }, ) } sock.handle(ipPack, tcpPack) @@ -279,23 +240,20 @@ func (s *TcpSocket) handleSyn(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, err e seq = s.network.opt.Seq } - ipResp, tcpResp, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). - SetSeq(seq). - SetAck(tcpPack.SequenceNumber + 1). + s.sendUnack = seq + s.sendNext = seq + + ipResp, _, err := NewPacketBuilder(s.network.opt). + SetAddr(s.SocketAddr). + SetSeq(s.sendNext). + SetAck(s.recvNext). SetFlags(tcpip.TcpSYN | tcpip.TcpACK). Build() if err != nil { return nil, err } - s.sendUnack = tcpResp.SequenceNumber - s.sendNext = tcpResp.SequenceNumber + 1 + s.sendNext++ return ipResp, nil } @@ -318,23 +276,20 @@ func (s *TcpSocket) handleSynResp(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, e tcpPack.AckNumber, ) } + s.State = tcpip.TcpStateEstablished + s.recvNext = tcpPack.SequenceNumber + 1 + ipResp, _, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). - SetSeq(s.sendNext - 1). - SetAck(tcpPack.SequenceNumber + 1). + SetAddr(s.SocketAddr). + SetSeq(s.sendNext). + SetAck(s.recvNext). SetFlags(tcpip.TcpACK). Build() if err != nil { return nil, err } s.sendUnack++ - s.recvNext = tcpPack.SequenceNumber + 1 select { case s.listener.acceptQueue <- s: @@ -348,7 +303,7 @@ func (s *TcpSocket) handleSynResp(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, e func (s *TcpSocket) handleFirstAck(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, err error) { s.State = tcpip.TcpStateEstablished s.sendUnack = tcpPack.AckNumber - s.synQueue.Delete(s.remotePort) + s.synQueue.Delete(s.RemotePort) select { case s.listener.acceptQueue <- s: default: @@ -356,12 +311,7 @@ func (s *TcpSocket) handleFirstAck(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, } s.network.addSocket(s) - s.network.bindSocket(SocketAddr{ - SrcIP: s.remoteIP, - SrcPort: s.remotePort, - DstIP: s.localIP, - DstPort: s.localPort, - }, s.fd) + s.network.bindSocket(s.SocketAddr, s.fd) go s.runloop() return nil, nil } @@ -389,12 +339,7 @@ func (s *TcpSocket) handleData(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPack, err } ipResp, _, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpACK). @@ -410,12 +355,7 @@ func (s *TcpSocket) handleFin() (resp *tcpip.IPPack, err error) { s.recvNext += 1 s.State = tcpip.TcpStateCloseWait ipResp, _, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpACK). @@ -432,12 +372,7 @@ func (s *TcpSocket) handleFin() (resp *tcpip.IPPack, err error) { func (s *TcpSocket) handleLastAck() { s.State = tcpip.TcpStateClosed s.network.removeSocket(s.fd) - s.network.unbindSocket(SocketAddr{ - SrcIP: s.remoteIP, - SrcPort: s.remotePort, - DstIP: s.localIP, - DstPort: s.localPort, - }) + s.network.unbindSocket(s.SocketAddr) } func (s *TcpSocket) handleFinWait1( @@ -462,25 +397,20 @@ func (s *TcpSocket) handleFinWait2Fin(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPac if err != nil { return nil, fmt.Errorf("encode tcp payload failed %w", err) } - if len(data) == 0 { - return nil, nil - } + // +1 for FIN s.recvNext = s.recvNext + uint32(len(data)) + 1 - select { - case s.readCh <- data: - default: - return nil, fmt.Errorf("the reader queue is full, drop the data") + if len(data) > 0 { + select { + case s.readCh <- data: + default: + return nil, fmt.Errorf("the reader queue is full, drop the data") + } } ipResp, _, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpACK). @@ -491,12 +421,8 @@ func (s *TcpSocket) handleFinWait2Fin(tcpPack *tcpip.TcpPack) (resp *tcpip.IPPac s.State = tcpip.TcpStateClosed s.network.removeSocket(s.fd) - s.network.unbindSocket(SocketAddr{ - SrcIP: s.remoteIP, - SrcPort: s.remotePort, - DstIP: s.localIP, - DstPort: s.localPort, - }) + s.network.unbindSocket(s.SocketAddr) + close(s.readCh) return ipResp, nil } @@ -550,12 +476,7 @@ func (s *TcpSocket) passiveCloseSocket() (ipResp *tcpip.IPPack, err error) { s.State = tcpip.TcpStateLastAck ipResp, tcpResp, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpFIN | tcpip.TcpACK). @@ -574,12 +495,7 @@ func (s *TcpSocket) activeCloseSocket() (ipResp *tcpip.IPPack, err error) { s.State = tcpip.TcpStateFinWait1 ipResp, tcpResp, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpFIN | tcpip.TcpACK). @@ -603,19 +519,13 @@ func (s *TcpSocket) handleSend(data []byte) (send int, resp *tcpip.IPPack, err e return 0, nil, nil } - log.Println("handle send data", len(data), string(data)) send = s.cacheSendData(data) if send == 0 { return 0, nil, nil } ipResp, _, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). + SetAddr(s.SocketAddr). SetSeq(s.sendNext). SetAck(s.recvNext). SetFlags(tcpip.TcpACK). @@ -644,7 +554,7 @@ func (s *TcpSocket) checkSeqAck(tcpPack *tcpip.TcpPack) (valid bool) { if s.sendUnack == s.sendNext { return tcpPack.AckNumber == s.sendNext } - return tcpPack.AckNumber >= s.sendUnack && tcpPack.AckNumber <= s.sendNext + return tcpPack.AckNumber > s.sendUnack && tcpPack.AckNumber <= s.sendNext } func (s *TcpSocket) cacheSendData(data []byte) int { @@ -697,22 +607,21 @@ func (s *TcpSocket) activeConnect() (ipResp *tcpip.IPPack, err error) { } else { seq = s.network.opt.Seq } - ipResp, tcpResp, err := NewPacketBuilder(s.network.opt). - SetAddr(SocketAddr{ - SrcIP: s.localIP, - SrcPort: s.localPort, - DstIP: s.remoteIP, - DstPort: s.remotePort, - }). - SetSeq(seq). + + s.sendUnack = seq + s.sendNext = seq + + ipResp, _, err = NewPacketBuilder(s.network.opt). + SetAddr(s.SocketAddr). + SetSeq(s.sendNext). SetFlags(tcpip.TcpSYN). Build() if err != nil { return nil, err } - s.sendUnack = tcpResp.SequenceNumber - s.sendNext = tcpResp.SequenceNumber + 1 + s.sendNext++ + s.listener = s return ipResp, nil diff --git a/network/netstack/socket/socket_it_test.go b/network/netstack/socket/socket_it_test.go index 526686c..1d56701 100644 --- a/network/netstack/socket/socket_it_test.go +++ b/network/netstack/socket/socket_it_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "golang.org/x/sys/unix" ) func EnsureNetwork() (err error) { @@ -31,7 +32,7 @@ func NewServer(hostAddr string, handler tcpHandler) (server *Server, err error) if err := EnsureNetwork(); err != nil { return nil, err } - serverFd, err := Socket() + serverFd, err := Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP) if err != nil { return nil, err } @@ -155,7 +156,7 @@ func TestActiveConnection(t *testing.T) { l.Close() }() - cfd, err := Socket() + cfd, err := Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP) assert.NoError(t, err) assert.NoError(t, Connect(cfd, serverAddr)) diff --git a/network/netstack/socket/socket_test.go b/network/netstack/socket/socket_test.go index 4969744..c62f357 100755 --- a/network/netstack/socket/socket_test.go +++ b/network/netstack/socket/socket_test.go @@ -44,10 +44,12 @@ func TestSocketServer(t *testing.T) { InitConnectSocket( connectSock, listenSock, - net.ParseIP(args.dstIp), - args.dstPort, - net.ParseIP(args.srcIp), - args.srcPort, + SocketAddr{ + RemoteIP: args.srcIp, + RemotePort: args.srcPort, + LocalIP: args.dstIp, + LocalPort: args.dstPort, + }, ) client := endpoint{ip: args.srcIp, port: args.srcPort, t: t} server := endpoint{ip: args.dstIp, port: args.dstPort, t: t}