From dec9fb0bd379b806aa1f5b3d9b949d4d2dbc8d41 Mon Sep 17 00:00:00 2001 From: Quentin Dufournet Date: Wed, 24 Sep 2025 13:17:06 +0200 Subject: [PATCH] fix: protect concurrent map access in port-forward Add synchronization around shared maps used for port forwarding. Introduced a read/write mutex for the recvChans map in portforward.go and for the mutexAck map in portforward_tcp.go. All map reads and writes are now wrapped with sync.RWMutex locks to prevent data races and intermittent panics under concurrent usage. Changelog: None Ticket: None Signed-off-by: Quentin Dufournet --- cmd/portforward.go | 49 ++++++++++++++++++++++++------------ cmd/portforward_tcp.go | 57 +++++++++++++++++++++++++++--------------- 2 files changed, 70 insertions(+), 36 deletions(-) diff --git a/cmd/portforward.go b/cmd/portforward.go index b20b3f3b..3e0263f1 100644 --- a/cmd/portforward.go +++ b/cmd/portforward.go @@ -21,6 +21,7 @@ import ( "os/signal" "strconv" "strings" + "sync" "time" "github.com/mendersoftware/go-lib-micro/ws" @@ -51,9 +52,9 @@ var portForwardCmd = &cobra.Command{ "it possible to port-forward to third hosts running in the device's network.\n" + "In this case, the specification will be LOCAL_PORT:REMOTE_HOST:REMOTE_PORT.\n\n" + "You can specify multiple port mapping specifications.", - Example: " mender-cli port-forward DEVICE_ID 8000:8000\n" + - " mender-cli port-forward DEVICE_ID udp/8000:8000\n" + - " mender-cli port-forward DEVICE_ID tcp/8000:192.168.1.1:8000", + Example: " mender-cli port-forward DEVICE_ID 8000:8000\n" + + " mender-cli port-forward DEVICE_ID udp/8000:8000\n" + + " mender-cli port-forward DEVICE_ID tcp/8000:192.168.1.1:8000", Args: cobra.MinimumNArgs(2), Run: func(c *cobra.Command, args []string) { cmd, err := NewPortForwardCmd(c, args) @@ -63,10 +64,10 @@ var portForwardCmd = &cobra.Command{ } var portForwardMaxDuration = 24 * time.Hour - var errPortForwardNotImplemented = errors.New( "port forward not implemented or enabled on the device", ) + var errRestart = errors.New("restart") func init() { @@ -87,17 +88,18 @@ type portMapping struct { // PortForwardCmd handles the port-forward command type PortForwardCmd struct { - server string - token string - skipVerify bool - deviceID string - sessionID string - bindingHost string - portMappings []portMapping - recvChans map[string]chan *ws.ProtoMsg - running bool - stop chan struct{} - err error + server string + token string + skipVerify bool + deviceID string + sessionID string + bindingHost string + portMappings []portMapping + recvChans map[string]chan *ws.ProtoMsg + recvChansLock sync.RWMutex + running bool + stop chan struct{} + err error } func getPortMappings(args []string) ([]portMapping, error) { @@ -106,6 +108,7 @@ func getPortMappings(args []string) ([]portMapping, error) { for _, arg := range args { remoteHost := localhost protocol := wspf.PortForwardProtocolTCP + if strings.Contains(arg, "/") { parts := strings.SplitN(arg, "/", 2) if parts[0] == protocolTCP { @@ -117,6 +120,7 @@ func getPortMappings(args []string) ([]portMapping, error) { } arg = parts[1] } + var localPort, remotePort int if strings.Contains(arg, ":") { parts := strings.SplitN(arg, ":", 3) @@ -124,10 +128,12 @@ func getPortMappings(args []string) ([]portMapping, error) { remoteHost = parts[1] parts = []string{parts[0], parts[2]} } + localPort, err = strconv.Atoi(parts[0]) if err != nil || localPort < 0 || localPort > 65536 { return nil, errors.New("invalid port number: " + parts[0]) } + remotePort, err = strconv.Atoi(parts[1]) if err != nil || remotePort < 0 || remotePort > 65536 { return nil, errors.New("invalid port number: " + parts[1]) @@ -140,6 +146,7 @@ func getPortMappings(args []string) ([]portMapping, error) { localPort = port remotePort = port } + portMappings = append(portMappings, portMapping{ Protocol: protocol, LocalPort: uint16(localPort), @@ -239,6 +246,7 @@ func (c *PortForwardCmd) run() error { if err != nil { return err } + go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans) case protocolUDP: forwarder, err := NewUDPPortForwarder(c.bindingHost, portMapping.LocalPort, @@ -246,6 +254,7 @@ func (c *PortForwardCmd) run() error { if err != nil { return err } + go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans) default: return errors.New("unknown protocol: " + portMapping.Protocol) @@ -313,6 +322,7 @@ func (c *PortForwardCmd) handshake(client *deviceconnect.Client) error { if err != nil { return err } + m := &ws.ProtoMsg{ Header: ws.ProtoHdr{ Proto: ws.ProtoTypeControl, @@ -320,6 +330,7 @@ func (c *PortForwardCmd) handshake(client *deviceconnect.Client) error { }, Body: body, } + err = client.WriteMessage(m) if err != nil { return err @@ -329,6 +340,7 @@ func (c *PortForwardCmd) handshake(client *deviceconnect.Client) error { if err != nil { return err } + if msg.Header.MsgType == ws.MessageTypeError { erro := new(ws.Error) _ = msgpack.Unmarshal(msg.Body, erro) @@ -366,6 +378,7 @@ func (c *PortForwardCmd) closeSession(client *deviceconnect.Client) error { MsgType: ws.MessageTypeClose, }, } + err := client.WriteMessage(m) if err != nil { return err @@ -411,7 +424,11 @@ func (c *PortForwardCmd) processIncomingMessages( m.Header.MsgType == wspf.MessageTypePortForwardStop) { connectionID, _ := m.Header.Properties[wspf.PropertyConnectionID].(string) if connectionID != "" { - if recvChan, ok := c.recvChans[connectionID]; ok { + c.recvChansLock.RLock() + recvChan, ok := c.recvChans[connectionID] + c.recvChansLock.RUnlock() + + if ok { recvChan <- m } } diff --git a/cmd/portforward_tcp.go b/cmd/portforward_tcp.go index 963d1eab..4cc055d0 100644 --- a/cmd/portforward_tcp.go +++ b/cmd/portforward_tcp.go @@ -34,10 +34,11 @@ import ( const portForwardTCPChannelSize = 20 type TCPPortForwarder struct { - listen net.Listener - remoteHost string - remotePort uint16 - mutexAck map[string]*sync.Mutex + listen net.Listener + remoteHost string + remotePort uint16 + mutexAck map[string]*sync.Mutex + mutexAckLock sync.RWMutex } func NewTCPPortForwarder( @@ -51,6 +52,7 @@ func NewTCPPortForwarder( if err != nil { return nil, err } + return &TCPPortForwarder{ listen: listen, remoteHost: remoteHost, @@ -76,6 +78,7 @@ func (p *TCPPortForwarder) Run( if err != nil { return } + fmt.Printf( "Handling connection from %s to %s\n", conn.RemoteAddr().String(), @@ -110,9 +113,14 @@ func (p *TCPPortForwarder) handleRequest( ) { defer conn.Close() + p.mutexAckLock.Lock() p.mutexAck[connectionID] = &sync.Mutex{} + p.mutexAckLock.Unlock() + defer func() { + p.mutexAckLock.Lock() delete(p.mutexAck, connectionID) + p.mutexAckLock.Unlock() }() errChan := make(chan error) @@ -124,11 +132,13 @@ func (p *TCPPortForwarder) handleRequest( RemoteHost: &p.remoteHost, RemotePort: &p.remotePort, } + body, err := msgpack.Marshal(portforwardNew) if err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) panic(err) } + m := &ws.ProtoMsg{ Header: ws.ProtoHdr{ Proto: ws.ProtoTypePortForward, @@ -140,9 +150,10 @@ func (p *TCPPortForwarder) handleRequest( }, Body: body, } - msgChan <- m + msgChan <- m sendStopMessage := true + defer func() { conn.Close() if sendStopMessage { @@ -195,9 +206,11 @@ func (p *TCPPortForwarder) handleRequest( } } else if m.Header.Proto == ws.ProtoTypePortForward && m.Header.MsgType == wspf.MessageTypePortForwardAck { - if m, ok := p.mutexAck[connectionID]; ok { - m.Unlock() + p.mutexAckLock.RLock() + if mutex, ok := p.mutexAck[connectionID]; ok { + mutex.Unlock() } + p.mutexAckLock.RUnlock() } case <-ctx.Done(): return @@ -214,21 +227,25 @@ func (p *TCPPortForwarder) handleRequest( } return case data := <-dataChan: - // lock the ack mutex, we don't allow more than one in-flight message - p.mutexAck[connectionID].Lock() + p.mutexAckLock.RLock() + mutex, ok := p.mutexAck[connectionID] + p.mutexAckLock.RUnlock() - m := &ws.ProtoMsg{ - Header: ws.ProtoHdr{ - Proto: ws.ProtoTypePortForward, - MsgType: wspf.MessageTypePortForward, - SessionID: sessionID, - Properties: map[string]interface{}{ - wspf.PropertyConnectionID: connectionID, + if ok { + mutex.Lock() + m := &ws.ProtoMsg{ + Header: ws.ProtoHdr{ + Proto: ws.ProtoTypePortForward, + MsgType: wspf.MessageTypePortForward, + SessionID: sessionID, + Properties: map[string]interface{}{ + wspf.PropertyConnectionID: connectionID, + }, }, - }, - Body: data, + Body: data, + } + msgChan <- m } - msgChan <- m case <-ctx.Done(): return } @@ -241,13 +258,13 @@ func (p *TCPPortForwarder) handleRequestConnection( conn net.Conn, ) { data := make([]byte, readBuffLength) - for { n, err := conn.Read(data) if err != nil { errChan <- err break } + if n > 0 { tmp := make([]byte, n) copy(tmp, data[:n])