From 8c40ad050b26ddda0386662a57a08a937aedc22e Mon Sep 17 00:00:00 2001 From: bobpaul <90864+bobpaul@users.noreply.github.com> Date: Sat, 8 Mar 2025 22:56:04 -0500 Subject: [PATCH] Restore armon/go-socks5/socks5.go which dependabot mangled Closes #61 --- vendor/github.com/armon/go-socks5/socks5.go | 32 +++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/vendor/github.com/armon/go-socks5/socks5.go b/vendor/github.com/armon/go-socks5/socks5.go index a17be68f..2d630fb4 100644 --- a/vendor/github.com/armon/go-socks5/socks5.go +++ b/vendor/github.com/armon/go-socks5/socks5.go @@ -55,6 +55,7 @@ type Config struct { type Server struct { config *Config authMethods map[uint8]Authenticator + isIPAllowed func(net.IP) bool } // New creates a new Server and potentially returns an error @@ -93,6 +94,11 @@ func New(conf *Config) (*Server, error) { server.authMethods[a.GetCode()] = a } + // Set default IP whitelist function + server.isIPAllowed = func(ip net.IP) bool { + return true // default allow all IPs + } + return server, nil } @@ -117,11 +123,37 @@ func (s *Server) Serve(l net.Listener) error { return nil } +// SetIPWhitelist sets the function to check if a given IP is allowed +func (s *Server) SetIPWhitelist(allowedIPs []net.IP) { + s.isIPAllowed = func(ip net.IP) bool { + for _, allowedIP := range allowedIPs { + if ip.Equal(allowedIP) { + return true + } + } + return false + } +} + // ServeConn is used to serve a single connection. func (s *Server) ServeConn(conn net.Conn) error { defer conn.Close() bufConn := bufio.NewReader(conn) + // Check client IP against whitelist + clientIP, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + s.config.Logger.Printf("[ERR] socks: Failed to get client IP address: %v", err) + return err + } + ip := net.ParseIP(clientIP) + if s.isIPAllowed(ip) { + s.config.Logger.Printf("[INFO] socks: Connection from allowed IP address: %s", clientIP) + } else { + s.config.Logger.Printf("[WARN] socks: Connection from not allowed IP address: %s", clientIP) + return fmt.Errorf("connection from not allowed IP address") + } + // Read the version byte version := []byte{0} if _, err := bufConn.Read(version); err != nil {