diff --git a/handle.go b/handle.go index e00ca85..cbe16c0 100644 --- a/handle.go +++ b/handle.go @@ -80,29 +80,40 @@ func (sf *Server) handleRequest(write io.Writer, req *Request) error { return fmt.Errorf("bind to %v blocked by rules", req.RawDestAddr) } + var last Handler // Switch on the command switch req.Command { case statute.CommandConnect: + last = sf.handleConnect if sf.userConnectHandle != nil { - return sf.userConnectHandle(ctx, write, req) + last = sf.userConnectHandle + } + if len(sf.userConnectMiddlewares) != 0 { + return sf.userConnectMiddlewares.Execute(ctx, write, req, last) } - return sf.handleConnect(ctx, write, req) case statute.CommandBind: + last = sf.handleBind if sf.userBindHandle != nil { - return sf.userBindHandle(ctx, write, req) + last = sf.userBindHandle + } + if len(sf.userBindMiddlewares) != 0 { + return sf.userBindMiddlewares.Execute(ctx, write, req, last) } - return sf.handleBind(ctx, write, req) case statute.CommandAssociate: + last = sf.handleAssociate if sf.userAssociateHandle != nil { - return sf.userAssociateHandle(ctx, write, req) + last = sf.userAssociateHandle + } + if len(sf.userAssociateMiddlewares) != 0 { + return sf.userAssociateMiddlewares.Execute(ctx, write, req, last) } - return sf.handleAssociate(ctx, write, req) default: if err := SendReply(write, statute.RepCommandNotSupported, nil); err != nil { return fmt.Errorf("failed to send reply, %v", err) } return fmt.Errorf("unsupported command[%v]", req.Command) } + return last(ctx, write, req) } // handleConnect is used to handle a connect command diff --git a/option.go b/option.go index aa13487..e211e03 100644 --- a/option.go +++ b/option.go @@ -124,3 +124,46 @@ func WithAssociateHandle(h func(ctx context.Context, writer io.Writer, request * s.userAssociateHandle = h } } + +// Handler is used to handle a user's commands +type Handler func(ctx context.Context, writer io.Writer, request *Request) error + +// WithMiddleware is used to add interceptors in chain +type Middleware func(ctx context.Context, writer io.Writer, request *Request) error + +// MiddlewareChain is used to add interceptors in chain +type MiddlewareChain []Middleware + +// Execute is used to add interceptors in chain +func (m MiddlewareChain) Execute(ctx context.Context, writer io.Writer, request *Request, last Handler) error { + if len(m) == 0 { + return nil + } + for i := 0; i < len(m); i++ { + if err := m[i](ctx, writer, request); err != nil { + return err + } + } + return last(ctx, writer, request) +} + +// WithConnectMiddleware is used to add interceptors in chain +func WithConnectMiddleware(m Middleware) Option { + return func(s *Server) { + s.userConnectMiddlewares = append(s.userConnectMiddlewares, m) + } +} + +// WithBindMiddleware is used to add interceptors in chain +func WithBindMiddleware(m Middleware) Option { + return func(s *Server) { + s.userBindMiddlewares = append(s.userBindMiddlewares, m) + } +} + +// WithAssociateMiddleware is used to add interceptors in chain +func WithAssociateMiddleware(m Middleware) Option { + return func(s *Server) { + s.userAssociateMiddlewares = append(s.userAssociateMiddlewares, m) + } +} diff --git a/server.go b/server.go index fece21c..e79dbfd 100644 --- a/server.go +++ b/server.go @@ -57,6 +57,10 @@ type Server struct { userConnectHandle func(ctx context.Context, writer io.Writer, request *Request) error userBindHandle func(ctx context.Context, writer io.Writer, request *Request) error userAssociateHandle func(ctx context.Context, writer io.Writer, request *Request) error + // user's middleware + userConnectMiddlewares MiddlewareChain + userBindMiddlewares MiddlewareChain + userAssociateMiddlewares MiddlewareChain } // NewServer creates a new Server diff --git a/server_test.go b/server_test.go index 6b206a9..6b93058 100644 --- a/server_test.go +++ b/server_test.go @@ -20,198 +20,629 @@ import ( ) func TestSOCKS5_Connect(t *testing.T) { - // Create a local listener - l, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) + t.Run("connect", func(t *testing.T) { + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) - go func() { - conn, err := l.Accept() + go func() { + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + + buf := make([]byte, 4) + _, err = io.ReadAtLeast(conn, buf, 4) + require.NoError(t, err) + assert.Equal(t, []byte("ping"), buf) + + conn.Write([]byte("pong")) //nolint: errcheck + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Create a socks server with UserPass auth. + cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} + srv := NewServer( + WithAuthMethods([]Authenticator{cator}), + WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), + WithDialAndRequest(func(ctx context.Context, network, addr string, request *Request) (net.Conn, error) { + require.Equal(t, network, "tcp") + require.Equal(t, addr, lAddr.String()) + return net.Dial(network, addr) + }), + ) + + // Start listening + go func() { + err := srv.ListenAndServe("tcp", "127.0.0.1:12365") + require.NoError(t, err) + }() + time.Sleep(10 * time.Millisecond) + + // Get a local conn + conn, err := net.Dial("tcp", "127.0.0.1:12365") require.NoError(t, err) - defer conn.Close() - buf := make([]byte, 4) - _, err = io.ReadAtLeast(conn, buf, 4) + // Connect, auth and connec to local + req := bytes.NewBuffer( + []byte{ + statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, // methods + statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', // userpass auth + }) + reqHead := statute.Request{ + Version: statute.VersionSocks5, + Command: statute.CommandConnect, + Reserved: 0, + DstAddr: statute.AddrSpec{ + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: lAddr.Port, + AddrType: statute.ATYPIPv4, + }, + } + req.Write(reqHead.Bytes()) + // Send a ping + req.WriteString("ping") + + // Send all the bytes + conn.Write(req.Bytes()) //nolint: errcheck + + // Verify response + expected := []byte{ + statute.VersionSocks5, statute.MethodUserPassAuth, // response use UserPass auth + statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success + } + rspHead := statute.Request{ + Version: statute.VersionSocks5, + Command: statute.RepSuccess, + Reserved: 0, + DstAddr: statute.AddrSpec{ + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: 0, + AddrType: statute.ATYPIPv4, + }, + } + expected = append(expected, rspHead.Bytes()...) + expected = append(expected, []byte("pong")...) + + out := make([]byte, len(expected)) + conn.SetDeadline(time.Now().Add(time.Second)) //nolint: errcheck + _, err = io.ReadFull(conn, out) + conn.SetDeadline(time.Time{}) //nolint: errcheck + require.NoError(t, err) + // Ignore the port + out[12] = 0 + out[13] = 0 + assert.Equal(t, expected, out) + }) + + t.Run("connect/customerHandler", func(t *testing.T) { + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) - assert.Equal(t, []byte("ping"), buf) - conn.Write([]byte("pong")) //nolint: errcheck - }() - lAddr := l.Addr().(*net.TCPAddr) + go func() { + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + + buf := make([]byte, 4) + _, err = io.ReadAtLeast(conn, buf, 4) + require.NoError(t, err) + assert.Equal(t, []byte("ping"), buf) + + conn.Write([]byte("pong")) //nolint: errcheck + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Create a socks server with UserPass auth. + cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} + srv := NewServer( + WithAuthMethods([]Authenticator{cator}), + WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), + WithDialAndRequest(func(ctx context.Context, network, addr string, request *Request) (net.Conn, error) { + require.Equal(t, network, "tcp") + require.Equal(t, addr, lAddr.String()) + return net.Dial(network, addr) + }), + WithConnectHandle(func(ctx context.Context, writer io.Writer, request *Request) error { + rsp := statute.Reply{ + Version: statute.VersionSocks5, + Response: 0x00, + BndAddr: statute.AddrSpec{ + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: 0, + AddrType: statute.ATYPIPv4, + }, + } + _, err := writer.Write(rsp.Bytes()) + writer.Write([]byte("gotcha!")) + if w, ok := writer.(closeWriter); ok { + w.CloseWrite() + } + return err + }), + ) + + // Start listening + go func() { + err := srv.ListenAndServe("tcp", "127.0.0.1:12369") + require.NoError(t, err) + }() + time.Sleep(10 * time.Millisecond) + + // Get a local conn + conn, err := net.Dial("tcp", "127.0.0.1:12369") + require.NoError(t, err) - // Create a socks server with UserPass auth. - cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} - srv := NewServer( - WithAuthMethods([]Authenticator{cator}), - WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), - WithDialAndRequest(func(ctx context.Context, network, addr string, request *Request) (net.Conn, error) { - require.Equal(t, network, "tcp") - require.Equal(t, addr, lAddr.String()) - return net.Dial(network, addr) - }), - ) + // Connect, auth and connec to local + req := bytes.NewBuffer( + []byte{ + statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, // methods + statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', // userpass auth + }) + reqHead := statute.Request{ + Version: statute.VersionSocks5, + Command: statute.CommandConnect, + Reserved: 0, + DstAddr: statute.AddrSpec{ + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: lAddr.Port, + AddrType: statute.ATYPIPv4, + }, + } + req.Write(reqHead.Bytes()) + // Send a ping + req.WriteString("ping") - // Start listening - go func() { - err := srv.ListenAndServe("tcp", "127.0.0.1:12365") + // Send all the bytes + conn.Write(req.Bytes()) //nolint: errcheck + + // Verify response + expected := []byte{ + statute.VersionSocks5, statute.MethodUserPassAuth, // response use UserPass auth + statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success + } + rspHead := statute.Request{ + Version: statute.VersionSocks5, + Command: statute.RepSuccess, + Reserved: 0, + DstAddr: statute.AddrSpec{ + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: 0, + AddrType: statute.ATYPIPv4, + }, + } + expected = append(expected, rspHead.Bytes()...) + expected = append(expected, []byte("gotcha!")...) + + out := make([]byte, len(expected)) + conn.SetDeadline(time.Now().Add(time.Second)) //nolint: errcheck + _, err = io.ReadFull(conn, out) + conn.SetDeadline(time.Time{}) //nolint: errcheck require.NoError(t, err) - }() - time.Sleep(10 * time.Millisecond) + // Ignore the port + out[12] = 0 + out[13] = 0 + assert.Equal(t, expected, out) + }) - // Get a local conn - conn, err := net.Dial("tcp", "127.0.0.1:12365") - require.NoError(t, err) + t.Run("connect/withMiddleware", func(t *testing.T) { + var middlewareCalled bool - // Connect, auth and connec to local - req := bytes.NewBuffer( - []byte{ - statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, // methods - statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', // userpass auth - }) - reqHead := statute.Request{ - Version: statute.VersionSocks5, - Command: statute.CommandConnect, - Reserved: 0, - DstAddr: statute.AddrSpec{ - FQDN: "", - IP: net.ParseIP("127.0.0.1"), - Port: lAddr.Port, - AddrType: statute.ATYPIPv4, - }, - } - req.Write(reqHead.Bytes()) - // Send a ping - req.WriteString("ping") + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) - // Send all the bytes - conn.Write(req.Bytes()) //nolint: errcheck + go func() { + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + + buf := make([]byte, 4) + _, err = io.ReadAtLeast(conn, buf, 4) + require.NoError(t, err) + assert.Equal(t, []byte("ping"), buf) + + conn.Write([]byte("pong")) //nolint: errcheck + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Create a socks server with UserPass auth. + cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} + srv := NewServer( + WithAuthMethods([]Authenticator{cator}), + WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), + WithDialAndRequest(func(ctx context.Context, network, addr string, request *Request) (net.Conn, error) { + require.Equal(t, network, "tcp") + require.Equal(t, addr, lAddr.String()) + return net.Dial(network, addr) + }), + WithConnectMiddleware(func(ctx context.Context, writer io.Writer, request *Request) error { + middlewareCalled = true + require.Equal(t, request.LocalAddr.String(), `127.0.0.1:12366`) + return nil + }), + ) + + // Start listening + go func() { + err := srv.ListenAndServe("tcp", "127.0.0.1:12366") + require.NoError(t, err) + }() + time.Sleep(10 * time.Millisecond) + + // Get a local conn + conn, err := net.Dial("tcp", "127.0.0.1:12366") + require.NoError(t, err) - // Verify response - expected := []byte{ - statute.VersionSocks5, statute.MethodUserPassAuth, // response use UserPass auth - statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success - } - rspHead := statute.Request{ - Version: statute.VersionSocks5, - Command: statute.RepSuccess, - Reserved: 0, - DstAddr: statute.AddrSpec{ - FQDN: "", - IP: net.ParseIP("127.0.0.1"), - Port: 0, - AddrType: statute.ATYPIPv4, - }, - } - expected = append(expected, rspHead.Bytes()...) - expected = append(expected, []byte("pong")...) + // Connect, auth and connec to local + req := bytes.NewBuffer( + []byte{ + statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, // methods + statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', // userpass auth + }) + reqHead := statute.Request{ + Version: statute.VersionSocks5, + Command: statute.CommandConnect, + Reserved: 0, + DstAddr: statute.AddrSpec{ + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: lAddr.Port, + AddrType: statute.ATYPIPv4, + }, + } + req.Write(reqHead.Bytes()) + // Send a ping + req.WriteString("ping") + + // Send all the bytes + conn.Write(req.Bytes()) //nolint: errcheck + + // Verify response + expected := []byte{ + statute.VersionSocks5, statute.MethodUserPassAuth, // response use UserPass auth + statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success + } + rspHead := statute.Request{ + Version: statute.VersionSocks5, + Command: statute.RepSuccess, + Reserved: 0, + DstAddr: statute.AddrSpec{ + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: 0, + AddrType: statute.ATYPIPv4, + }, + } + expected = append(expected, rspHead.Bytes()...) + expected = append(expected, []byte("pong")...) + + out := make([]byte, len(expected)) + conn.SetDeadline(time.Now().Add(time.Second)) //nolint: errcheck + _, err = io.ReadFull(conn, out) + conn.SetDeadline(time.Time{}) //nolint: errcheck + require.NoError(t, err) + // Ignore the port + out[12] = 0 + out[13] = 0 + assert.Equal(t, expected, out) + assert.True(t, middlewareCalled, "middleware not called") + }) + + t.Run("connect/withMiddlewareError", func(t *testing.T) { + var middlewareCalled bool + + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + go func() { + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + + buf := make([]byte, 4) + _, err = io.ReadAtLeast(conn, buf, 4) + require.NoError(t, err) + assert.Equal(t, []byte("ping"), buf) + + conn.Write([]byte("pong")) //nolint: errcheck + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Create a socks server with UserPass auth. + cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} + srv := NewServer( + WithAuthMethods([]Authenticator{cator}), + WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), + WithDialAndRequest(func(ctx context.Context, network, addr string, request *Request) (net.Conn, error) { + require.Equal(t, network, "tcp") + require.Equal(t, addr, lAddr.String()) + return net.Dial(network, addr) + }), + WithConnectMiddleware(func(ctx context.Context, writer io.Writer, request *Request) error { + middlewareCalled = true + require.Equal(t, request.LocalAddr.String(), `127.0.0.1:12367`) + return errors.New("Address is blocked!") + }), + ) + + // Start listening + go func() { + err := srv.ListenAndServe("tcp", "127.0.0.1:12367") + require.NoError(t, err) + }() + time.Sleep(10 * time.Millisecond) + + // Get a local conn + conn, err := net.Dial("tcp", "127.0.0.1:12367") + require.NoError(t, err) + + // Connect, auth and connec to local + req := bytes.NewBuffer( + []byte{ + statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, // methods + statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', // userpass auth + }) + reqHead := statute.Request{ + Version: statute.VersionSocks5, + Command: statute.CommandConnect, + Reserved: 0, + DstAddr: statute.AddrSpec{ + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: lAddr.Port, + AddrType: statute.ATYPIPv4, + }, + } + req.Write(reqHead.Bytes()) + // Send a ping + req.WriteString("ping") + + // Send all the bytes + conn.Write(req.Bytes()) //nolint: errcheck + + // Verify response + expected := []byte{ + statute.VersionSocks5, statute.MethodUserPassAuth, // response use UserPass auth + statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success + } + rspHead := statute.Request{ + Version: statute.VersionSocks5, + Command: statute.RepSuccess, + Reserved: 0, + DstAddr: statute.AddrSpec{ + FQDN: "", + IP: net.ParseIP("127.0.0.1"), + Port: 0, + AddrType: statute.ATYPIPv4, + }, + } + expected = append(expected, rspHead.Bytes()...) + expected = append(expected, []byte("pong")...) + + out := make([]byte, len(expected)) + conn.SetDeadline(time.Now().Add(time.Second)) //nolint: errcheck + _, err = io.ReadFull(conn, out) + conn.SetDeadline(time.Time{}) //nolint: errcheck + require.ErrorIs(t, err, io.ErrUnexpectedEOF) + assert.Equal(t, []byte{0x5, 0x2, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, out) + assert.True(t, middlewareCalled, "middleware not called") + }) - out := make([]byte, len(expected)) - conn.SetDeadline(time.Now().Add(time.Second)) //nolint: errcheck - _, err = io.ReadFull(conn, out) - conn.SetDeadline(time.Time{}) //nolint: errcheck - require.NoError(t, err) - // Ignore the port - out[12] = 0 - out[13] = 0 - assert.Equal(t, expected, out) } func TestSOCKS5_Associate(t *testing.T) { - locIP := net.ParseIP("127.0.0.1") - // Create a local listener - serverAddr := &net.UDPAddr{IP: locIP, Port: 12399} - server, err := net.ListenUDP("udp", serverAddr) - require.NoError(t, err) - defer server.Close() - - go func() { - buf := make([]byte, 2048) - for { - n, remote, err := server.ReadFrom(buf) - if err != nil { - return + t.Run("associate", func(t *testing.T) { + locIP := net.ParseIP("127.0.0.1") + // Create a local listener + serverAddr := &net.UDPAddr{IP: locIP, Port: 12399} + server, err := net.ListenUDP("udp", serverAddr) + require.NoError(t, err) + defer server.Close() + + go func() { + buf := make([]byte, 2048) + for { + n, remote, err := server.ReadFrom(buf) + if err != nil { + return + } + require.Equal(t, []byte("ping"), buf[:n]) + + server.WriteTo([]byte("pong"), remote) //nolint: errcheck } - require.Equal(t, []byte("ping"), buf[:n]) + }() + + clientAddr := &net.UDPAddr{IP: locIP, Port: 12499} + client, err := net.ListenUDP("udp", clientAddr) + require.NoError(t, err) + defer client.Close() + + // Create a socks server + cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} + proxySrv := NewServer( + WithAuthMethods([]Authenticator{cator}), + WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), + ) + // Start listening + go func() { + err := proxySrv.ListenAndServe("tcp", "127.0.0.1:12355") + require.NoError(t, err) + }() + time.Sleep(10 * time.Millisecond) + + // Get a local conn + conn, err := net.Dial("tcp", "127.0.0.1:12355") + require.NoError(t, err) - server.WriteTo([]byte("pong"), remote) //nolint: errcheck + // Connect, auth and connec to local + req := bytes.NewBuffer( + []byte{ + statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, + statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', + }) + reqHead := statute.Request{ + Version: statute.VersionSocks5, + Command: statute.CommandAssociate, + Reserved: 0, + DstAddr: statute.AddrSpec{ + FQDN: "", + IP: clientAddr.IP, + Port: clientAddr.Port, + AddrType: statute.ATYPIPv4, + }, + } + req.Write(reqHead.Bytes()) + // Send all the bytes + conn.Write(req.Bytes()) //nolint: errcheck + + // Verify response + expected := []byte{ + statute.VersionSocks5, statute.MethodUserPassAuth, // use user password auth + statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success } - }() - clientAddr := &net.UDPAddr{IP: locIP, Port: 12499} - client, err := net.ListenUDP("udp", clientAddr) - require.NoError(t, err) - defer client.Close() + out := make([]byte, len(expected)) + conn.SetDeadline(time.Now().Add(time.Second)) //nolint: errcheck + _, err = io.ReadFull(conn, out) + conn.SetDeadline(time.Time{}) //nolint: errcheck + require.NoError(t, err) + require.Equal(t, expected, out) - // Create a socks server - cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} - proxySrv := NewServer( - WithAuthMethods([]Authenticator{cator}), - WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), - ) - // Start listening - go func() { - err := proxySrv.ListenAndServe("tcp", "127.0.0.1:12355") + rspHead, err := statute.ParseReply(conn) require.NoError(t, err) - }() - time.Sleep(10 * time.Millisecond) + require.Equal(t, statute.VersionSocks5, rspHead.Version) + require.Equal(t, statute.RepSuccess, rspHead.Response) + + ipByte := []byte(serverAddr.IP.To4()) + portByte := make([]byte, 2) + binary.BigEndian.PutUint16(portByte, uint16(serverAddr.Port)) + + msgBytes := []byte{0, 0, 0, statute.ATYPIPv4} + msgBytes = append(msgBytes, ipByte...) + msgBytes = append(msgBytes, portByte...) + msgBytes = append(msgBytes, []byte("ping")...) + client.WriteTo(msgBytes, &net.UDPAddr{IP: locIP, Port: rspHead.BndAddr.Port}) //nolint: errcheck + // t.Logf("proxy bind listen port: %d", rspHead.BndAddr.Port) + response := make([]byte, 1024) + n, _, err := client.ReadFrom(response) + require.NoError(t, err) + assert.Equal(t, []byte("pong"), response[n-4:n]) + time.Sleep(time.Second * 1) + }) - // Get a local conn - conn, err := net.Dial("tcp", "127.0.0.1:12355") - require.NoError(t, err) + t.Run("associate/withMiddleware", func(t *testing.T) { + var middlewareCalled bool - // Connect, auth and connec to local - req := bytes.NewBuffer( - []byte{ - statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, - statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', - }) - reqHead := statute.Request{ - Version: statute.VersionSocks5, - Command: statute.CommandAssociate, - Reserved: 0, - DstAddr: statute.AddrSpec{ - FQDN: "", - IP: clientAddr.IP, - Port: clientAddr.Port, - AddrType: statute.ATYPIPv4, - }, - } - req.Write(reqHead.Bytes()) - // Send all the bytes - conn.Write(req.Bytes()) //nolint: errcheck - - // Verify response - expected := []byte{ - statute.VersionSocks5, statute.MethodUserPassAuth, // use user password auth - statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success - } + locIP := net.ParseIP("127.0.0.1") + // Create a local listener + serverAddr := &net.UDPAddr{IP: locIP, Port: 12399} + server, err := net.ListenUDP("udp", serverAddr) + require.NoError(t, err) + defer server.Close() + + go func() { + buf := make([]byte, 2048) + for { + n, remote, err := server.ReadFrom(buf) + if err != nil { + return + } + require.Equal(t, []byte("ping"), buf[:n]) + + server.WriteTo([]byte("pong"), remote) //nolint: errcheck + } + }() - out := make([]byte, len(expected)) - conn.SetDeadline(time.Now().Add(time.Second)) //nolint: errcheck - _, err = io.ReadFull(conn, out) - conn.SetDeadline(time.Time{}) //nolint: errcheck - require.NoError(t, err) - require.Equal(t, expected, out) + clientAddr := &net.UDPAddr{IP: locIP, Port: 12499} + client, err := net.ListenUDP("udp", clientAddr) + require.NoError(t, err) + defer client.Close() + + // Create a socks server + cator := UserPassAuthenticator{StaticCredentials{"foo": "bar"}} + proxySrv := NewServer( + WithAuthMethods([]Authenticator{cator}), + WithLogger(NewLogger(log.New(os.Stdout, "socks5: ", log.LstdFlags))), + WithAssociateMiddleware(func(ctx context.Context, writer io.Writer, request *Request) error { + require.Equal(t, request.DestAddr.Port, 12499) + middlewareCalled = true + return nil + }), + ) + // Start listening + go func() { + err := proxySrv.ListenAndServe("tcp", "127.0.0.1:12356") + require.NoError(t, err) + }() + time.Sleep(10 * time.Millisecond) + + // Get a local conn + conn, err := net.Dial("tcp", "127.0.0.1:12356") + require.NoError(t, err) - rspHead, err := statute.ParseReply(conn) - require.NoError(t, err) - require.Equal(t, statute.VersionSocks5, rspHead.Version) - require.Equal(t, statute.RepSuccess, rspHead.Response) - - ipByte := []byte(serverAddr.IP.To4()) - portByte := make([]byte, 2) - binary.BigEndian.PutUint16(portByte, uint16(serverAddr.Port)) - - msgBytes := []byte{0, 0, 0, statute.ATYPIPv4} - msgBytes = append(msgBytes, ipByte...) - msgBytes = append(msgBytes, portByte...) - msgBytes = append(msgBytes, []byte("ping")...) - client.WriteTo(msgBytes, &net.UDPAddr{IP: locIP, Port: rspHead.BndAddr.Port}) //nolint: errcheck - // t.Logf("proxy bind listen port: %d", rspHead.BndAddr.Port) - response := make([]byte, 1024) - n, _, err := client.ReadFrom(response) - require.NoError(t, err) - assert.Equal(t, []byte("pong"), response[n-4:n]) - time.Sleep(time.Second * 1) + // Connect, auth and connec to local + req := bytes.NewBuffer( + []byte{ + statute.VersionSocks5, 2, statute.MethodNoAuth, statute.MethodUserPassAuth, + statute.UserPassAuthVersion, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r', + }) + reqHead := statute.Request{ + Version: statute.VersionSocks5, + Command: statute.CommandAssociate, + Reserved: 0, + DstAddr: statute.AddrSpec{ + FQDN: "", + IP: clientAddr.IP, + Port: clientAddr.Port, + AddrType: statute.ATYPIPv4, + }, + } + req.Write(reqHead.Bytes()) + // Send all the bytes + conn.Write(req.Bytes()) //nolint: errcheck + + // Verify response + expected := []byte{ + statute.VersionSocks5, statute.MethodUserPassAuth, // use user password auth + statute.UserPassAuthVersion, statute.AuthSuccess, // response auth success + } + + out := make([]byte, len(expected)) + conn.SetDeadline(time.Now().Add(time.Second)) //nolint: errcheck + _, err = io.ReadFull(conn, out) + conn.SetDeadline(time.Time{}) //nolint: errcheck + require.NoError(t, err) + require.Equal(t, expected, out) + + rspHead, err := statute.ParseReply(conn) + require.NoError(t, err) + require.Equal(t, statute.VersionSocks5, rspHead.Version) + require.Equal(t, statute.RepSuccess, rspHead.Response) + + ipByte := []byte(serverAddr.IP.To4()) + portByte := make([]byte, 2) + binary.BigEndian.PutUint16(portByte, uint16(serverAddr.Port)) + + msgBytes := []byte{0, 0, 0, statute.ATYPIPv4} + msgBytes = append(msgBytes, ipByte...) + msgBytes = append(msgBytes, portByte...) + msgBytes = append(msgBytes, []byte("ping")...) + client.WriteTo(msgBytes, &net.UDPAddr{IP: locIP, Port: rspHead.BndAddr.Port}) //nolint: errcheck + // t.Logf("proxy bind listen port: %d", rspHead.BndAddr.Port) + response := make([]byte, 1024) + n, _, err := client.ReadFrom(response) + require.NoError(t, err) + assert.Equal(t, []byte("pong"), response[n-4:n]) + assert.True(t, middlewareCalled, "middleware not called") + time.Sleep(time.Second * 1) + }) } func Test_SocksWithProxy(t *testing.T) {