From 2de5c61611c0a025e87c18638afc170dbacaa87f Mon Sep 17 00:00:00 2001 From: Saber Haj Rabiee Date: Wed, 28 Aug 2024 00:55:38 -0700 Subject: [PATCH] fix: replace SetFwmark with dialer and listener setters --- cmd/outline-ss-server/main.go | 7 +++- service/tcp.go | 12 +----- service/tcp_test.go | 6 +-- service/udp.go | 75 ++++++++++++++++++++--------------- 4 files changed, 54 insertions(+), 46 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 6e893b5f..ca0bdb61 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -34,6 +34,7 @@ import ( "gopkg.in/yaml.v2" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" + onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/Jigsaw-Code/outline-ss-server/service" ) @@ -89,9 +90,11 @@ func (s *SSServer) startPort(portNum int, fwmark uint) error { // TODO: Register initial data metrics at zero. tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout) - tcpHandler.SetFwmark(fwmark) + tcpHandler.SetTargetDialer(service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, fwmark)) packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m) - packetHandler.SetFwmark(fwmark) + packetHandler.SetTargetPacketListener(func() (net.PacketConn, *onet.ConnectionError) { + return service.MakeTargetPacketListener(fwmark) + }) s.ports[portNum] = port go service.StreamServe(service.WrapStreamListener(listener.AcceptTCP), tcpHandler.Handle) go packetHandler.Handle(port.packetConn) diff --git a/service/tcp.go b/service/tcp.go index 0f1efda1..efba123f 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -181,15 +181,11 @@ func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout ti } } -var defaultDialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP, 0) +var defaultDialer = MakeValidatingTCPStreamDialer(onet.RequirePublicIP, 0) // fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management. // Value of 0 disables fwmark (SO_MARK) -func (h *tcpHandler) SetFwmark(fwmark uint) { - h.dialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP, fwmark) -} - -func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator, fwmark uint) transport.StreamDialer { +func MakeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator, fwmark uint) transport.StreamDialer { return &transport.TCPDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error { if fwmark > 0 { err := c.Control(func(fd uintptr) { @@ -212,10 +208,6 @@ type TCPHandler interface { Handle(ctx context.Context, conn transport.StreamConn) // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses. SetTargetDialer(dialer transport.StreamDialer) - - // SetFwmark sets Firewall Mark for outgoing packets - // Value of 0 disables fwmark - SetFwmark(fwmark uint) } func (s *tcpHandler) SetTargetDialer(dialer transport.StreamDialer) { diff --git a/service/tcp_test.go b/service/tcp_test.go index 40a2b274..7eb27346 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -361,7 +361,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) - handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0)) + handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0)) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -396,7 +396,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) - handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0)) + handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0)) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) @@ -432,7 +432,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) - handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0)) + handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0)) done := make(chan struct{}) go func() { StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) diff --git a/service/udp.go b/service/udp.go index 4476dded..09907d82 100644 --- a/service/udp.go +++ b/service/udp.go @@ -33,6 +33,9 @@ import ( onet "github.com/Jigsaw-Code/outline-ss-server/net" ) +// Type alias for creating target UDP sockets +type UDPTargetConnFunc = func() (net.PacketConn, *onet.ConnectionError) + // UDPMetrics is used to report metrics on UDP connections. type UDPMetrics interface { ipinfo.IPInfoMap @@ -87,11 +90,11 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis } type packetHandler struct { - natTimeout time.Duration - ciphers CipherList - m UDPMetrics - targetIPValidator onet.TargetIPValidator - fwmark uint + natTimeout time.Duration + ciphers CipherList + m UDPMetrics + targetIPValidator onet.TargetIPValidator + targetListenerFunc UDPTargetConnFunc } // NewPacketHandler creates a UDPService @@ -99,11 +102,38 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr return &packetHandler{ natTimeout: natTimeout, ciphers: cipherList, m: m, - targetIPValidator: onet.RequirePublicIP, - fwmark: 0, + targetIPValidator: onet.RequirePublicIP, + targetListenerFunc: defaultTargetListner, } } +var defaultTargetListner = func() (net.PacketConn, *onet.ConnectionError) { + return MakeTargetPacketListener(0) +} + +// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management. +// Value of 0 disables fwmark (SO_MARK) +func MakeTargetPacketListener(fwmark uint) (net.PacketConn, *onet.ConnectionError) { + udpConn, err := net.ListenPacket("udp", "") + if err != nil { + return nil, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) + } + + if fwmark > 0 { + file, err := udpConn.(*net.UDPConn).File() + if err != nil { + return nil, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to get UDP socket file", err) + } + defer file.Close() + + err = syscall.SetsockoptInt(int(file.Fd()), syscall.SOL_SOCKET, syscall.SO_MARK, int(fwmark)) + if err != nil { + slog.Error("Set fwmark failed.", "err", os.NewSyscallError("failed to set mark for UDP socket", err)) + } + } + return udpConn, nil +} + // PacketHandler is a running UDP shadowsocks proxy that can be stopped. type PacketHandler interface { // SetTargetIPValidator sets the function to be used to validate the target IP addresses. @@ -111,19 +141,15 @@ type PacketHandler interface { // Handle returns after clientConn closes and all the sub goroutines return. Handle(clientConn net.PacketConn) - // SetFwmark sets Firewall Mark for outgoing packets - // Value of 0 disables fwmark - SetFwmark(fwmark uint) + SetTargetPacketListener(UDPTargetConnFunc) } -func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { - h.targetIPValidator = targetIPValidator +func (h *packetHandler) SetTargetPacketListener(connFunc UDPTargetConnFunc) { + h.targetListenerFunc = connFunc } -// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management. -// Value of 0 disables fwmark (SO_MARK) -func (h *packetHandler) SetFwmark(fwmark uint) { - h.fwmark = fwmark +func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { + h.targetIPValidator = targetIPValidator } // Listen on addr for encrypted packets and basically do UDP NAT. @@ -190,22 +216,9 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { return onetErr } - udpConn, err := net.ListenPacket("udp", "") + udpConn, err := h.targetListenerFunc() if err != nil { - return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) - } - - if h.fwmark > 0 { - file, err := udpConn.(*net.UDPConn).File() - if err != nil { - return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to get UDP socket file", err) - } - defer file.Close() - - err = syscall.SetsockoptInt(int(file.Fd()), syscall.SOL_SOCKET, syscall.SO_MARK, int(h.fwmark)) - if err != nil { - slog.Error("Set fwmark failed.", "err", os.NewSyscallError("failed to set mark for UDP socket", err)) - } + return err } targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID)