Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions network/netstack/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func main() {
panic(err)
}
socket.SetupDefaultNetwork(context.Background(), tun, socket.NetworkOptions{Debug: true})
serverFd, err := socket.Socket()
serverFd, err := socket.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP)
if err != nil {
panic(err)
}
Expand All @@ -43,7 +43,7 @@ func main() {
fmt.Printf("server start success %d, listen on: %s\n", serverFd, hostAddr)
for {
fmt.Println("accepting...")
connFd, err := socket.Accept(serverFd)
connFd, _, err := socket.Accept(serverFd)
if err != nil {
log.Println(err)
continue
Expand Down
73 changes: 32 additions & 41 deletions network/netstack/socket/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type SockFile interface {
Write(b []byte) (n int, err error)
}

func Socket() (fd int, err error) {
func Socket(domain int, typ int, protocol int) (fd int, err error) {
if defaultNetwork == nil {
return 0, ErrNoNetwork
}
Expand All @@ -43,11 +43,12 @@ func Listen(fd int, backlog uint) (err error) {
return defaultNetwork.listen(fd, backlog)
}

func Accept(fd int) (cfd int, err error) {
func Accept(fd int) (cfd int, addr SocketAddr, err error) {
if defaultNetwork == nil {
return 0, ErrNoNetwork
return 0, SocketAddr{}, ErrNoNetwork
}
return defaultNetwork.accept(fd)
cfd, addr, err = defaultNetwork.accept(fd)
return cfd, addr, err
}

func AcceptWithTimeout(fd int, timeout time.Duration) (cfd int, err error) {
Expand Down Expand Up @@ -239,26 +240,26 @@ func (n *Network) handle(data []byte) {
tcpPack := ipPack.Payload.(*tcpip.TcpPack)
sock, ok := n.getSocket(
SocketAddr{
SrcIP: ipPack.SrcIP.String(),
SrcPort: tcpPack.SrcPort,
DstIP: ipPack.DstIP.String(),
DstPort: tcpPack.DstPort,
RemoteIP: ipPack.SrcIP.String(),
RemotePort: tcpPack.SrcPort,
LocalIP: ipPack.DstIP.String(),
LocalPort: tcpPack.DstPort,
},
)
if !ok {
return
}
sock.Lock()
if sock.State == tcpip.TcpStateUnInitialized {
log.Printf("socket %s:%d is not initialized,drop packet", sock.localIP, sock.localPort)
log.Printf("socket %s:%d is not initialized,drop packet", sock.LocalIP, sock.LocalPort)
sock.Unlock()
return
}
sock.Unlock()
select {
case sock.writeCh <- ipPack:
default:
log.Printf("socket %s:%d is full,drop packet", sock.localIP, sock.localPort)
log.Printf("socket %s:%d is full,drop packet", sock.LocalIP, sock.LocalPort)
}
}

Expand All @@ -273,11 +274,11 @@ func (n *Network) bind(fd int, addr string) (err error) {
if !ok {
return fmt.Errorf("%w: %d", ErrNoSocket, fd)
}
sock.localIP = ip.String()
sock.localPort = port
sock.LocalIP = ip.String()
sock.LocalPort = port
n.bindSocket(SocketAddr{
DstIP: ip.String(),
DstPort: port,
LocalIP: ip.String(),
LocalPort: port,
}, fd)
return nil
}
Expand All @@ -291,12 +292,13 @@ func (n *Network) listen(fd int, backlog uint) (err error) {
return sock.Listen(backlog)
}

func (n *Network) accept(fd int) (cfd int, err error) {
func (n *Network) accept(fd int) (cfd int, addr SocketAddr, err error) {
sock, ok := n.getSocketByFd(fd)
if !ok {
return 0, fmt.Errorf("%w: %d", ErrNoSocket, fd)
return 0, SocketAddr{}, fmt.Errorf("%w: %d", ErrNoSocket, fd)
}
return sock.Accept()
cfd, addr, err = sock.Accept()
return cfd, addr, err
}

func (n *Network) acceptWithTimeout(fd int, timeout time.Duration) (cfd int, err error) {
Expand Down Expand Up @@ -344,34 +346,25 @@ func (n *Network) connect(fd int, serverAddr string) (err error) {
return fmt.Errorf("%w: %d", ErrNoSocket, fd)
}
var addr SocketAddr
if sock.localIP == "" && sock.localPort == 0 {
if sock.LocalIP == "" && sock.LocalPort == 0 {
addr, err = n.getAvailableAddress()
if err != nil {
return err
}
sock.localIP = addr.DstIP
sock.localPort = addr.DstPort
} else {
n.unbindSocket(SocketAddr{
DstIP: sock.localIP,
DstPort: sock.localPort,
LocalIP: sock.LocalIP,
LocalPort: sock.LocalPort,
})
addr = SocketAddr{
DstIP: sock.localIP,
DstPort: sock.localPort,
LocalIP: sock.LocalIP,
LocalPort: sock.LocalPort,
}
}
addr.SrcIP = serverIP.String()
addr.SrcPort = serverPort
addr.RemoteIP = serverIP.String()
addr.RemotePort = serverPort
n.bindSocket(addr, fd)
InitConnectSocket(
sock,
nil,
net.ParseIP(sock.localIP),
sock.localPort,
serverIP,
serverPort,
)
InitConnectSocket(sock, nil, addr)
return sock.Connect()
}

Expand All @@ -381,8 +374,8 @@ func (n *Network) getSocket(addr SocketAddr) (sock *TcpSocket, ok bool) {
return n.getSocketByFd(value.(int))
}
newAddr := SocketAddr{
DstIP: addr.DstIP,
DstPort: addr.DstPort,
LocalIP: addr.LocalIP,
LocalPort: addr.LocalPort,
}
value, ok = n.socketFds.Load(newAddr)
if ok {
Expand Down Expand Up @@ -416,11 +409,9 @@ func (n *Network) getAvailableAddress() (addr SocketAddr, err error) {
localIp := ip.String()
var p uint16
for p = n.opt.IpLocalPortRange.Start; p <= n.opt.IpLocalPortRange.End; p++ {
if _, ok := n.socketFds.Load(SocketAddr{DstIP: localIp, DstPort: p}); !ok {
return SocketAddr{
DstIP: localIp,
DstPort: p,
}, nil
addr := SocketAddr{LocalIP: localIp, LocalPort: p}
if _, ok := n.socketFds.Load(addr); !ok {
return addr, nil
}
}
return SocketAddr{}, fmt.Errorf("no available address")
Expand Down
10 changes: 5 additions & 5 deletions network/netstack/socket/packet_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ func NewPacketBuilder(opt NetworkOptions) *PacketBuilder {
}

func (b *PacketBuilder) SetAddr(addr SocketAddr) *PacketBuilder {
srcIP := net.ParseIP(addr.SrcIP).To4()
dstIP := net.ParseIP(addr.DstIP).To4()
srcIP := net.ParseIP(addr.LocalIP).To4()
dstIP := net.ParseIP(addr.RemoteIP).To4()
if srcIP == nil || dstIP == nil {
b.err = fmt.Errorf("invalid IPv4 address: %s, %s", addr.SrcIP, addr.DstIP)
b.err = fmt.Errorf("invalid IPv4 address: %s, %s", addr.LocalIP, addr.RemoteIP)
return b
}
b.ip.IPHeader.SrcIP = srcIP
Expand All @@ -46,8 +46,8 @@ func (b *PacketBuilder) SetAddr(addr SocketAddr) *PacketBuilder {
SrcIP: b.ip.IPHeader.SrcIP,
DstIP: b.ip.IPHeader.DstIP,
}
b.tcp.TcpHeader.SrcPort = addr.SrcPort
b.tcp.TcpHeader.DstPort = addr.DstPort
b.tcp.TcpHeader.SrcPort = addr.LocalPort
b.tcp.TcpHeader.DstPort = addr.RemotePort
return b
}

Expand Down
Loading
Loading