Skip to content

Commit

Permalink
fix: replace SetFwmark with dialer and listener setters
Browse files Browse the repository at this point in the history
  • Loading branch information
sabify committed Aug 28, 2024
1 parent aeccec9 commit 2de5c61
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 46 deletions.
7 changes: 5 additions & 2 deletions cmd/outline-ss-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"gopkg.in/yaml.v2"

"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
onet "github.com/Jigsaw-Code/outline-ss-server/net"
"github.com/Jigsaw-Code/outline-ss-server/service"
)

Expand Down Expand Up @@ -89,9 +90,11 @@ func (s *SSServer) startPort(portNum int, fwmark uint) error {
// TODO: Register initial data metrics at zero.

tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout)
tcpHandler.SetFwmark(fwmark)
tcpHandler.SetTargetDialer(service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, fwmark))
packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m)
packetHandler.SetFwmark(fwmark)
packetHandler.SetTargetPacketListener(func() (net.PacketConn, *onet.ConnectionError) {
return service.MakeTargetPacketListener(fwmark)
})
s.ports[portNum] = port
go service.StreamServe(service.WrapStreamListener(listener.AcceptTCP), tcpHandler.Handle)
go packetHandler.Handle(port.packetConn)
Expand Down
12 changes: 2 additions & 10 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,11 @@ func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout ti
}
}

var defaultDialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP, 0)
var defaultDialer = MakeValidatingTCPStreamDialer(onet.RequirePublicIP, 0)

// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management.
// Value of 0 disables fwmark (SO_MARK)
func (h *tcpHandler) SetFwmark(fwmark uint) {
h.dialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP, fwmark)
}

func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator, fwmark uint) transport.StreamDialer {
func MakeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator, fwmark uint) transport.StreamDialer {
return &transport.TCPDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
if fwmark > 0 {
err := c.Control(func(fd uintptr) {
Expand All @@ -212,10 +208,6 @@ type TCPHandler interface {
Handle(ctx context.Context, conn transport.StreamConn)
// SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses.
SetTargetDialer(dialer transport.StreamDialer)

// SetFwmark sets Firewall Mark for outgoing packets
// Value of 0 disables fwmark
SetFwmark(fwmark uint)
}

func (s *tcpHandler) SetTargetDialer(dialer transport.StreamDialer) {
Expand Down
6 changes: 3 additions & 3 deletions service/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0))
handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0))
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -396,7 +396,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0))
handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0))
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -432,7 +432,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0))
handler.SetTargetDialer(MakeValidatingTCPStreamDialer(allowAll, 0))
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down
75 changes: 44 additions & 31 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ import (
onet "github.com/Jigsaw-Code/outline-ss-server/net"
)

// Type alias for creating target UDP sockets
type UDPTargetConnFunc = func() (net.PacketConn, *onet.ConnectionError)

// UDPMetrics is used to report metrics on UDP connections.
type UDPMetrics interface {
ipinfo.IPInfoMap
Expand Down Expand Up @@ -87,43 +90,66 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis
}

type packetHandler struct {
natTimeout time.Duration
ciphers CipherList
m UDPMetrics
targetIPValidator onet.TargetIPValidator
fwmark uint
natTimeout time.Duration
ciphers CipherList
m UDPMetrics
targetIPValidator onet.TargetIPValidator
targetListenerFunc UDPTargetConnFunc
}

// NewPacketHandler creates a UDPService
func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetrics) PacketHandler {
return &packetHandler{
natTimeout: natTimeout,
ciphers: cipherList, m: m,
targetIPValidator: onet.RequirePublicIP,
fwmark: 0,
targetIPValidator: onet.RequirePublicIP,
targetListenerFunc: defaultTargetListner,
}
}

var defaultTargetListner = func() (net.PacketConn, *onet.ConnectionError) {
return MakeTargetPacketListener(0)
}

// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management.
// Value of 0 disables fwmark (SO_MARK)
func MakeTargetPacketListener(fwmark uint) (net.PacketConn, *onet.ConnectionError) {
udpConn, err := net.ListenPacket("udp", "")
if err != nil {
return nil, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err)
}

if fwmark > 0 {
file, err := udpConn.(*net.UDPConn).File()
if err != nil {
return nil, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to get UDP socket file", err)
}
defer file.Close()

err = syscall.SetsockoptInt(int(file.Fd()), syscall.SOL_SOCKET, syscall.SO_MARK, int(fwmark))
if err != nil {
slog.Error("Set fwmark failed.", "err", os.NewSyscallError("failed to set mark for UDP socket", err))
}
}
return udpConn, nil
}

// PacketHandler is a running UDP shadowsocks proxy that can be stopped.
type PacketHandler interface {
// SetTargetIPValidator sets the function to be used to validate the target IP addresses.
SetTargetIPValidator(targetIPValidator onet.TargetIPValidator)
// Handle returns after clientConn closes and all the sub goroutines return.
Handle(clientConn net.PacketConn)

// SetFwmark sets Firewall Mark for outgoing packets
// Value of 0 disables fwmark
SetFwmark(fwmark uint)
SetTargetPacketListener(UDPTargetConnFunc)
}

func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) {
h.targetIPValidator = targetIPValidator
func (h *packetHandler) SetTargetPacketListener(connFunc UDPTargetConnFunc) {
h.targetListenerFunc = connFunc
}

// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management.
// Value of 0 disables fwmark (SO_MARK)
func (h *packetHandler) SetFwmark(fwmark uint) {
h.fwmark = fwmark
func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) {
h.targetIPValidator = targetIPValidator
}

// Listen on addr for encrypted packets and basically do UDP NAT.
Expand Down Expand Up @@ -190,22 +216,9 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) {
return onetErr
}

udpConn, err := net.ListenPacket("udp", "")
udpConn, err := h.targetListenerFunc()
if err != nil {
return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err)
}

if h.fwmark > 0 {
file, err := udpConn.(*net.UDPConn).File()
if err != nil {
return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to get UDP socket file", err)
}
defer file.Close()

err = syscall.SetsockoptInt(int(file.Fd()), syscall.SOL_SOCKET, syscall.SO_MARK, int(h.fwmark))
if err != nil {
slog.Error("Set fwmark failed.", "err", os.NewSyscallError("failed to set mark for UDP socket", err))
}
return err
}

targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID)
Expand Down

0 comments on commit 2de5c61

Please sign in to comment.