Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions cmd/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"os/signal"
"strconv"
"strings"
"sync"
"time"

"github.com/mendersoftware/go-lib-micro/ws"
Expand Down Expand Up @@ -51,9 +52,9 @@ var portForwardCmd = &cobra.Command{
"it possible to port-forward to third hosts running in the device's network.\n" +
"In this case, the specification will be LOCAL_PORT:REMOTE_HOST:REMOTE_PORT.\n\n" +
"You can specify multiple port mapping specifications.",
Example: " mender-cli port-forward DEVICE_ID 8000:8000\n" +
" mender-cli port-forward DEVICE_ID udp/8000:8000\n" +
" mender-cli port-forward DEVICE_ID tcp/8000:192.168.1.1:8000",
Example: " mender-cli port-forward DEVICE_ID 8000:8000\n" +
" mender-cli port-forward DEVICE_ID udp/8000:8000\n" +
" mender-cli port-forward DEVICE_ID tcp/8000:192.168.1.1:8000",
Args: cobra.MinimumNArgs(2),
Run: func(c *cobra.Command, args []string) {
cmd, err := NewPortForwardCmd(c, args)
Expand All @@ -63,10 +64,10 @@ var portForwardCmd = &cobra.Command{
}

var portForwardMaxDuration = 24 * time.Hour

var errPortForwardNotImplemented = errors.New(
"port forward not implemented or enabled on the device",
)

var errRestart = errors.New("restart")

func init() {
Expand All @@ -87,17 +88,18 @@ type portMapping struct {

// PortForwardCmd handles the port-forward command
type PortForwardCmd struct {
server string
token string
skipVerify bool
deviceID string
sessionID string
bindingHost string
portMappings []portMapping
recvChans map[string]chan *ws.ProtoMsg
running bool
stop chan struct{}
err error
server string
token string
skipVerify bool
deviceID string
sessionID string
bindingHost string
portMappings []portMapping
recvChans map[string]chan *ws.ProtoMsg
recvChansLock sync.RWMutex
running bool
stop chan struct{}
err error
}

func getPortMappings(args []string) ([]portMapping, error) {
Expand All @@ -106,6 +108,7 @@ func getPortMappings(args []string) ([]portMapping, error) {
for _, arg := range args {
remoteHost := localhost
protocol := wspf.PortForwardProtocolTCP

if strings.Contains(arg, "/") {
parts := strings.SplitN(arg, "/", 2)
if parts[0] == protocolTCP {
Expand All @@ -117,17 +120,20 @@ func getPortMappings(args []string) ([]portMapping, error) {
}
arg = parts[1]
}

var localPort, remotePort int
if strings.Contains(arg, ":") {
parts := strings.SplitN(arg, ":", 3)
if len(parts) == 3 {
remoteHost = parts[1]
parts = []string{parts[0], parts[2]}
}

localPort, err = strconv.Atoi(parts[0])
if err != nil || localPort < 0 || localPort > 65536 {
return nil, errors.New("invalid port number: " + parts[0])
}

remotePort, err = strconv.Atoi(parts[1])
if err != nil || remotePort < 0 || remotePort > 65536 {
return nil, errors.New("invalid port number: " + parts[1])
Expand All @@ -140,6 +146,7 @@ func getPortMappings(args []string) ([]portMapping, error) {
localPort = port
remotePort = port
}

portMappings = append(portMappings, portMapping{
Protocol: protocol,
LocalPort: uint16(localPort),
Expand Down Expand Up @@ -239,13 +246,15 @@ func (c *PortForwardCmd) run() error {
if err != nil {
return err
}

go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans)
case protocolUDP:
forwarder, err := NewUDPPortForwarder(c.bindingHost, portMapping.LocalPort,
portMapping.RemoteHost, portMapping.RemotePort)
if err != nil {
return err
}

go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans)
default:
return errors.New("unknown protocol: " + portMapping.Protocol)
Expand Down Expand Up @@ -313,13 +322,15 @@ func (c *PortForwardCmd) handshake(client *deviceconnect.Client) error {
if err != nil {
return err
}

m := &ws.ProtoMsg{
Header: ws.ProtoHdr{
Proto: ws.ProtoTypeControl,
MsgType: ws.MessageTypeOpen,
},
Body: body,
}

err = client.WriteMessage(m)
if err != nil {
return err
Expand All @@ -329,6 +340,7 @@ func (c *PortForwardCmd) handshake(client *deviceconnect.Client) error {
if err != nil {
return err
}

if msg.Header.MsgType == ws.MessageTypeError {
erro := new(ws.Error)
_ = msgpack.Unmarshal(msg.Body, erro)
Expand Down Expand Up @@ -366,6 +378,7 @@ func (c *PortForwardCmd) closeSession(client *deviceconnect.Client) error {
MsgType: ws.MessageTypeClose,
},
}

err := client.WriteMessage(m)
if err != nil {
return err
Expand Down Expand Up @@ -411,7 +424,11 @@ func (c *PortForwardCmd) processIncomingMessages(
m.Header.MsgType == wspf.MessageTypePortForwardStop) {
connectionID, _ := m.Header.Properties[wspf.PropertyConnectionID].(string)
if connectionID != "" {
if recvChan, ok := c.recvChans[connectionID]; ok {
c.recvChansLock.RLock()
recvChan, ok := c.recvChans[connectionID]
c.recvChansLock.RUnlock()

if ok {
recvChan <- m
}
}
Expand Down
57 changes: 37 additions & 20 deletions cmd/portforward_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ import (
const portForwardTCPChannelSize = 20

type TCPPortForwarder struct {
listen net.Listener
remoteHost string
remotePort uint16
mutexAck map[string]*sync.Mutex
listen net.Listener
remoteHost string
remotePort uint16
mutexAck map[string]*sync.Mutex
mutexAckLock sync.RWMutex
}

func NewTCPPortForwarder(
Expand All @@ -51,6 +52,7 @@ func NewTCPPortForwarder(
if err != nil {
return nil, err
}

return &TCPPortForwarder{
listen: listen,
remoteHost: remoteHost,
Expand All @@ -76,6 +78,7 @@ func (p *TCPPortForwarder) Run(
if err != nil {
return
}

fmt.Printf(
"Handling connection from %s to %s\n",
conn.RemoteAddr().String(),
Expand Down Expand Up @@ -110,9 +113,14 @@ func (p *TCPPortForwarder) handleRequest(
) {
defer conn.Close()

p.mutexAckLock.Lock()
p.mutexAck[connectionID] = &sync.Mutex{}
p.mutexAckLock.Unlock()

defer func() {
p.mutexAckLock.Lock()
delete(p.mutexAck, connectionID)
p.mutexAckLock.Unlock()
}()

errChan := make(chan error)
Expand All @@ -124,11 +132,13 @@ func (p *TCPPortForwarder) handleRequest(
RemoteHost: &p.remoteHost,
RemotePort: &p.remotePort,
}

body, err := msgpack.Marshal(portforwardNew)
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err.Error())
panic(err)
}

m := &ws.ProtoMsg{
Header: ws.ProtoHdr{
Proto: ws.ProtoTypePortForward,
Expand All @@ -140,9 +150,10 @@ func (p *TCPPortForwarder) handleRequest(
},
Body: body,
}
msgChan <- m

msgChan <- m
sendStopMessage := true

defer func() {
conn.Close()
if sendStopMessage {
Expand Down Expand Up @@ -195,9 +206,11 @@ func (p *TCPPortForwarder) handleRequest(
}
} else if m.Header.Proto == ws.ProtoTypePortForward &&
m.Header.MsgType == wspf.MessageTypePortForwardAck {
if m, ok := p.mutexAck[connectionID]; ok {
m.Unlock()
p.mutexAckLock.RLock()
if mutex, ok := p.mutexAck[connectionID]; ok {
mutex.Unlock()
}
p.mutexAckLock.RUnlock()
Comment on lines +209 to +213
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you consider:

Suggested change
p.mutexAckLock.RLock()
if mutex, ok := p.mutexAck[connectionID]; ok {
mutex.Unlock()
}
p.mutexAckLock.RUnlock()
p.mutexAckLock.RLock()
mutex, ok := p.mutexAck[connectionID]
p.mutexAckLock.RUnlock()
if ok && mutex != nil {
mutex.Unlock()
}

you know, I do panic when I see mutexes in the critical section ;-)

}
case <-ctx.Done():
return
Expand All @@ -214,21 +227,25 @@ func (p *TCPPortForwarder) handleRequest(
}
return
case data := <-dataChan:
// lock the ack mutex, we don't allow more than one in-flight message
p.mutexAck[connectionID].Lock()
p.mutexAckLock.RLock()
mutex, ok := p.mutexAck[connectionID]
p.mutexAckLock.RUnlock()

m := &ws.ProtoMsg{
Header: ws.ProtoHdr{
Proto: ws.ProtoTypePortForward,
MsgType: wspf.MessageTypePortForward,
SessionID: sessionID,
Properties: map[string]interface{}{
wspf.PropertyConnectionID: connectionID,
if ok {
mutex.Lock()
m := &ws.ProtoMsg{
Header: ws.ProtoHdr{
Proto: ws.ProtoTypePortForward,
MsgType: wspf.MessageTypePortForward,
SessionID: sessionID,
Properties: map[string]interface{}{
wspf.PropertyConnectionID: connectionID,
},
},
},
Body: data,
Body: data,
}
msgChan <- m
}
msgChan <- m
case <-ctx.Done():
return
}
Expand All @@ -241,13 +258,13 @@ func (p *TCPPortForwarder) handleRequestConnection(
conn net.Conn,
) {
data := make([]byte, readBuffLength)

for {
n, err := conn.Read(data)
if err != nil {
errChan <- err
break
}

if n > 0 {
tmp := make([]byte, n)
copy(tmp, data[:n])
Expand Down