Skip to content

Commit

Permalink
fix: make setting fwmark via a dedicated method
Browse files Browse the repository at this point in the history
  • Loading branch information
sabify committed Aug 27, 2024
1 parent 60aad45 commit aeccec9
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 44 deletions.
7 changes: 5 additions & 2 deletions cmd/outline-ss-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ func (s *SSServer) startPort(portNum int, fwmark uint) error {
port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()}
authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m)
// TODO: Register initial data metrics at zero.
tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout, fwmark)
packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m, fwmark)

tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout)
tcpHandler.SetFwmark(fwmark)
packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m)
packetHandler.SetFwmark(fwmark)
s.ports[portNum] = port
go service.StreamServe(service.WrapStreamListener(listener.AcceptTCP), tcpHandler.Handle)
go packetHandler.Handle(port.packetConn)
Expand Down
26 changes: 10 additions & 16 deletions internal/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ import (

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
logging "github.com/op/go-logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
"github.com/Jigsaw-Code/outline-ss-server/service"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
logging "github.com/op/go-logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const maxUDPPacketSize = 64 * 1024
Expand Down Expand Up @@ -134,7 +133,7 @@ func TestTCPEcho(t *testing.T) {
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -203,7 +202,7 @@ func TestRestrictedAddresses(t *testing.T) {
const testTimeout = 200 * time.Millisecond
testMetrics := &statusMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
done := make(chan struct{})
go func() {
service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -263,23 +262,18 @@ var _ service.UDPMetrics = (*fakeUDPMetrics)(nil)
func (m *fakeUDPMetrics) GetIPInfo(ip net.IP) (ipinfo.IPInfo, error) {
return ipinfo.IPInfo{CountryCode: "QQ"}, nil
}

func (m *fakeUDPMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) {
m.up = append(m.up, udpRecord{clientInfo, accessKey, status, clientProxyBytes, proxyTargetBytes})
}

func (m *fakeUDPMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) {
m.down = append(m.down, udpRecord{clientInfo, accessKey, status, targetProxyBytes, proxyClientBytes})
}

func (m *fakeUDPMetrics) AddUDPNatEntry(clientAddr net.Addr, accessKey string) {
m.natAdded++
}

func (m *fakeUDPMetrics) RemoveUDPNatEntry(clientAddr net.Addr, accessKey string) {
// Not tested because it requires waiting for a long timeout.
}

func (m *fakeUDPMetrics) AddUDPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) {}

func TestUDPEcho(t *testing.T) {
Expand All @@ -295,7 +289,7 @@ func TestUDPEcho(t *testing.T) {
t.Fatal(err)
}
testMetrics := &fakeUDPMetrics{}
proxy := service.NewPacketHandler(time.Hour, cipherList, testMetrics, 0)
proxy := service.NewPacketHandler(time.Hour, cipherList, testMetrics)
proxy.SetTargetIPValidator(allowAll)
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -390,7 +384,7 @@ func BenchmarkTCPThroughput(b *testing.B) {
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -454,7 +448,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) {
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -527,7 +521,7 @@ func BenchmarkUDPEcho(b *testing.B) {
if err != nil {
b.Fatal(err)
}
proxy := service.NewPacketHandler(time.Hour, cipherList, &service.NoOpUDPMetrics{}, 0)
proxy := service.NewPacketHandler(time.Hour, cipherList, &service.NoOpUDPMetrics{})
proxy.SetTargetIPValidator(allowAll)
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -571,7 +565,7 @@ func BenchmarkUDPManyKeys(b *testing.B) {
if err != nil {
b.Fatal(err)
}
proxy := service.NewPacketHandler(time.Hour, cipherList, &service.NoOpUDPMetrics{}, 0)
proxy := service.NewPacketHandler(time.Hour, cipherList, &service.NoOpUDPMetrics{})
proxy.SetTargetIPValidator(allowAll)
done := make(chan struct{})
go func() {
Expand Down
16 changes: 13 additions & 3 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,7 @@ type tcpHandler struct {
}

// NewTCPService creates a TCPService
func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration, fwmark uint) TCPHandler {
defaultDialer := makeValidatingTCPStreamDialer(onet.RequirePublicIP, fwmark)

func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler {
return &tcpHandler{
m: m,
readTimeout: timeout,
Expand All @@ -183,6 +181,14 @@ func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout ti
}
}

var defaultDialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP, 0)

// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management.
// Value of 0 disables fwmark (SO_MARK)
func (h *tcpHandler) SetFwmark(fwmark uint) {
h.dialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP, fwmark)
}

func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator, fwmark uint) transport.StreamDialer {
return &transport.TCPDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
if fwmark > 0 {
Expand All @@ -206,6 +212,10 @@ type TCPHandler interface {
Handle(ctx context.Context, conn transport.StreamConn)
// SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses.
SetTargetDialer(dialer transport.StreamDialer)

// SetFwmark sets Firewall Mark for outgoing packets
// Value of 0 disables fwmark
SetFwmark(fwmark uint)
}

func (s *tcpHandler) SetTargetDialer(dialer transport.StreamDialer) {
Expand Down
16 changes: 8 additions & 8 deletions service/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func TestProbeRandom(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond, 0)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -360,7 +360,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond, 0)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0))
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -395,7 +395,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond, 0)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0))
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -431,7 +431,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond, 0)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0))
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -474,7 +474,7 @@ func TestProbeServerBytesModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond, 0)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -505,7 +505,7 @@ func TestReplayDefense(t *testing.T) {
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
Expand Down Expand Up @@ -584,7 +584,7 @@ func TestReverseReplayDefense(t *testing.T) {
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
Expand Down Expand Up @@ -655,7 +655,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout)

done := make(chan struct{})
go func() {
Expand Down
14 changes: 12 additions & 2 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ type packetHandler struct {
}

// NewPacketHandler creates a UDPService
func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetrics, fwmark uint) PacketHandler {
func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetrics) PacketHandler {
return &packetHandler{
natTimeout: natTimeout,
ciphers: cipherList, m: m,
targetIPValidator: onet.RequirePublicIP,
fwmark: fwmark,
fwmark: 0,
}
}

Expand All @@ -110,12 +110,22 @@ type PacketHandler interface {
SetTargetIPValidator(targetIPValidator onet.TargetIPValidator)
// Handle returns after clientConn closes and all the sub goroutines return.
Handle(clientConn net.PacketConn)

// SetFwmark sets Firewall Mark for outgoing packets
// Value of 0 disables fwmark
SetFwmark(fwmark uint)
}

func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) {
h.targetIPValidator = targetIPValidator
}

// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management.
// Value of 0 disables fwmark (SO_MARK)
func (h *packetHandler) SetFwmark(fwmark uint) {
h.fwmark = fwmark
}

// Listen on addr for encrypted packets and basically do UDP NAT.
// We take the ciphers as a pointer because it gets replaced on config updates.
func (h *packetHandler) Handle(clientConn net.PacketConn) {
Expand Down
17 changes: 4 additions & 13 deletions service/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,19 @@ import (
"time"

"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
onet "github.com/Jigsaw-Code/outline-ss-server/net"
logging "github.com/op/go-logging"
"github.com/shadowsocks/go-shadowsocks2/socks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
onet "github.com/Jigsaw-Code/outline-ss-server/net"
)

const timeout = 5 * time.Minute

var clientAddr = net.UDPAddr{IP: []byte{192, 0, 2, 1}, Port: 12345}

var targetAddr = net.UDPAddr{IP: []byte{192, 0, 2, 2}, Port: 54321}

var dnsAddr = net.UDPAddr{IP: []byte{192, 0, 2, 3}, Port: 53}

var natCryptoKey *shadowsocks.EncryptionKey

func init() {
Expand Down Expand Up @@ -113,21 +109,16 @@ var _ UDPMetrics = (*natTestMetrics)(nil)
func (m *natTestMetrics) GetIPInfo(net.IP) (ipinfo.IPInfo, error) {
return ipinfo.IPInfo{}, nil
}

func (m *natTestMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) {
m.upstreamPackets = append(m.upstreamPackets, udpReport{clientInfo, accessKey, status, clientProxyBytes, proxyTargetBytes})
}

func (m *natTestMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) {
}

func (m *natTestMetrics) AddUDPNatEntry(clientAddr net.Addr, accessKey string) {
m.natEntriesAdded++
}

func (m *natTestMetrics) RemoveUDPNatEntry(clientAddr net.Addr, accessKey string) {
}

func (m *natTestMetrics) AddUDPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) {}

// Takes a validation policy, and returns the metrics it
Expand All @@ -137,7 +128,7 @@ func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTest
cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey
clientConn := makePacketConn()
metrics := &natTestMetrics{}
handler := NewPacketHandler(timeout, ciphers, metrics, 0)
handler := NewPacketHandler(timeout, ciphers, metrics)
handler.SetTargetIPValidator(validator)
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -489,7 +480,7 @@ func TestUDPEarlyClose(t *testing.T) {
}
testMetrics := &natTestMetrics{}
const testTimeout = 200 * time.Millisecond
s := NewPacketHandler(testTimeout, cipherList, testMetrics, 0)
s := NewPacketHandler(testTimeout, cipherList, testMetrics)

clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
Expand Down

0 comments on commit aeccec9

Please sign in to comment.