Skip to content

Commit

Permalink
feat: enable fwmark (SO_MARK) for outgoing sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
sabify committed Aug 25, 2024
1 parent 0f6ad5b commit 60aad45
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 45 deletions.
27 changes: 18 additions & 9 deletions cmd/outline-ss-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,20 @@ import (
"time"

"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
"github.com/Jigsaw-Code/outline-ss-server/service"
"github.com/lmittmann/tint"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/term"
"gopkg.in/yaml.v2"

"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
"github.com/Jigsaw-Code/outline-ss-server/service"
)

var logLevel = new(slog.LevelVar) // Info by default
var logHandler slog.Handler
var (
logLevel = new(slog.LevelVar) // Info by default
logHandler slog.Handler
)

// Set by goreleaser default ldflags. See https://goreleaser.com/customization/build/
var version = "dev"
Expand Down Expand Up @@ -68,7 +71,7 @@ type SSServer struct {
ports map[int]*ssPort
}

func (s *SSServer) startPort(portNum int) error {
func (s *SSServer) startPort(portNum int, fwmark uint) error {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: portNum})
if err != nil {
//lint:ignore ST1005 Shadowsocks is capitalized.
Expand All @@ -84,8 +87,8 @@ func (s *SSServer) startPort(portNum int) error {
port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()}
authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m)
// TODO: Register initial data metrics at zero.
tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout)
packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m)
tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout, fwmark)
packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m, fwmark)
s.ports[portNum] = port
go service.StreamServe(service.WrapStreamListener(listener.AcceptTCP), tcpHandler.Handle)
go packetHandler.Handle(port.packetConn)
Expand Down Expand Up @@ -144,15 +147,20 @@ func (s *SSServer) loadConfig(filename string) error {
return fmt.Errorf("failed to remove port %v: %w", portNum, err)
}
} else if count == +1 {
if err := s.startPort(portNum); err != nil {
if err := s.startPort(portNum, config.Fwmark); err != nil {
return err
}
}
}
for portNum, cipherList := range portCiphers {
s.ports[portNum].cipherList.Update(cipherList)
}
slog.Info("Loaded config.", "access keys", len(config.Keys), "ports", len(s.ports))
slog.Info("Loaded config.", "access keys", len(config.Keys), "ports", len(s.ports), "fwmark", func() any {
if config.Fwmark == 0 {
return "disabled"
}
return config.Fwmark
}())
s.m.SetNumAccessKeys(len(config.Keys), len(portCiphers))
return nil
}
Expand Down Expand Up @@ -199,6 +207,7 @@ type Config struct {
Cipher string
Secret string
}
Fwmark uint
}

func readConfig(filename string) (*Config, error) {
Expand Down
26 changes: 16 additions & 10 deletions internal/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ import (

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
"github.com/Jigsaw-Code/outline-ss-server/service"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
logging "github.com/op/go-logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
"github.com/Jigsaw-Code/outline-ss-server/service"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
)

const maxUDPPacketSize = 64 * 1024
Expand Down Expand Up @@ -133,7 +134,7 @@ func TestTCPEcho(t *testing.T) {
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -202,7 +203,7 @@ func TestRestrictedAddresses(t *testing.T) {
const testTimeout = 200 * time.Millisecond
testMetrics := &statusMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
done := make(chan struct{})
go func() {
service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -262,18 +263,23 @@ var _ service.UDPMetrics = (*fakeUDPMetrics)(nil)
func (m *fakeUDPMetrics) GetIPInfo(ip net.IP) (ipinfo.IPInfo, error) {
return ipinfo.IPInfo{CountryCode: "QQ"}, nil
}

func (m *fakeUDPMetrics) AddUDPPacketFromClient(clientInfo ipinfo.IPInfo, accessKey, status string, clientProxyBytes, proxyTargetBytes int) {
m.up = append(m.up, udpRecord{clientInfo, accessKey, status, clientProxyBytes, proxyTargetBytes})
}

func (m *fakeUDPMetrics) AddUDPPacketFromTarget(clientInfo ipinfo.IPInfo, accessKey, status string, targetProxyBytes, proxyClientBytes int) {
m.down = append(m.down, udpRecord{clientInfo, accessKey, status, targetProxyBytes, proxyClientBytes})
}

func (m *fakeUDPMetrics) AddUDPNatEntry(clientAddr net.Addr, accessKey string) {
m.natAdded++
}

func (m *fakeUDPMetrics) RemoveUDPNatEntry(clientAddr net.Addr, accessKey string) {
// Not tested because it requires waiting for a long timeout.
}

func (m *fakeUDPMetrics) AddUDPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) {}

func TestUDPEcho(t *testing.T) {
Expand All @@ -289,7 +295,7 @@ func TestUDPEcho(t *testing.T) {
t.Fatal(err)
}
testMetrics := &fakeUDPMetrics{}
proxy := service.NewPacketHandler(time.Hour, cipherList, testMetrics)
proxy := service.NewPacketHandler(time.Hour, cipherList, testMetrics, 0)
proxy.SetTargetIPValidator(allowAll)
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -384,7 +390,7 @@ func BenchmarkTCPThroughput(b *testing.B) {
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -448,7 +454,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) {
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout)
handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -521,7 +527,7 @@ func BenchmarkUDPEcho(b *testing.B) {
if err != nil {
b.Fatal(err)
}
proxy := service.NewPacketHandler(time.Hour, cipherList, &service.NoOpUDPMetrics{})
proxy := service.NewPacketHandler(time.Hour, cipherList, &service.NoOpUDPMetrics{}, 0)
proxy.SetTargetIPValidator(allowAll)
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -565,7 +571,7 @@ func BenchmarkUDPManyKeys(b *testing.B) {
if err != nil {
b.Fatal(err)
}
proxy := service.NewPacketHandler(time.Hour, cipherList, &service.NoOpUDPMetrics{})
proxy := service.NewPacketHandler(time.Hour, cipherList, &service.NoOpUDPMetrics{}, 0)
proxy.SetTargetIPValidator(allowAll)
done := make(chan struct{})
go func() {
Expand Down
28 changes: 23 additions & 5 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ import (
"log/slog"
"net"
"net/netip"
"os"
"sync"
"syscall"
"time"

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
"github.com/shadowsocks/go-shadowsocks2/socks"

"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
onet "github.com/Jigsaw-Code/outline-ss-server/net"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
"github.com/shadowsocks/go-shadowsocks2/socks"
)

// TCPMetrics is used to report metrics on TCP connections.
Expand Down Expand Up @@ -170,7 +172,9 @@ type tcpHandler struct {
}

// NewTCPService creates a TCPService
func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler {
func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration, fwmark uint) TCPHandler {
defaultDialer := makeValidatingTCPStreamDialer(onet.RequirePublicIP, fwmark)

return &tcpHandler{
m: m,
readTimeout: timeout,
Expand All @@ -179,10 +183,19 @@ func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout ti
}
}

var defaultDialer = makeValidatingTCPStreamDialer(onet.RequirePublicIP)

func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) transport.StreamDialer {
func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator, fwmark uint) transport.StreamDialer {
return &transport.TCPDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
if fwmark > 0 {
err := c.Control(func(fd uintptr) {
err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(fwmark))
if err != nil {
slog.Error("Set fwmark failed.", "err", os.NewSyscallError("failed to set mark for TCP socket", err))
}
})
if err != nil {
slog.Error("Set TCPDialer Control func failed.", "err", err)
}
}
ip, _, _ := net.SplitHostPort(address)
return targetIPValidator(net.ParseIP(ip))
}}}
Expand Down Expand Up @@ -397,12 +410,17 @@ var _ TCPMetrics = (*NoOpTCPMetrics)(nil)

func (m *NoOpTCPMetrics) AddClosedTCPConnection(clientInfo ipinfo.IPInfo, clientAddr net.Addr, accessKey string, status string, data metrics.ProxyMetrics, duration time.Duration) {
}

func (m *NoOpTCPMetrics) GetIPInfo(net.IP) (ipinfo.IPInfo, error) {
return ipinfo.IPInfo{}, nil
}

func (m *NoOpTCPMetrics) AddOpenTCPConnection(clientInfo ipinfo.IPInfo) {}

func (m *NoOpTCPMetrics) AddAuthenticatedTCPConnection(clientAddr net.Addr, accessKey string) {
}

func (m *NoOpTCPMetrics) AddTCPProbe(status, drainResult string, listenerId string, clientProxyBytes int64) {
}

func (m *NoOpTCPMetrics) AddTCPCipherSearch(accessKeyFound bool, timeToCipher time.Duration) {}
28 changes: 15 additions & 13 deletions service/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ import (

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
logging "github.com/op/go-logging"
"github.com/shadowsocks/go-shadowsocks2/socks"
"github.com/stretchr/testify/require"

"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
)

func init() {
Expand Down Expand Up @@ -233,6 +234,7 @@ func (m *probeTestMetrics) AddClosedTCPConnection(clientInfo ipinfo.IPInfo, clie
func (m *probeTestMetrics) GetIPInfo(net.IP) (ipinfo.IPInfo, error) {
return ipinfo.IPInfo{}, nil
}

func (m *probeTestMetrics) AddOpenTCPConnection(clientInfo ipinfo.IPInfo) {
}

Expand Down Expand Up @@ -281,7 +283,7 @@ func TestProbeRandom(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond, 0)
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -358,8 +360,8 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond, 0)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0))
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -393,8 +395,8 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond, 0)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0))
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -429,8 +431,8 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond, 0)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll, 0))
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -472,7 +474,7 @@ func TestProbeServerBytesModified(t *testing.T) {
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond)
handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond, 0)
done := make(chan struct{})
go func() {
StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle)
Expand Down Expand Up @@ -503,7 +505,7 @@ func TestReplayDefense(t *testing.T) {
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
Expand Down Expand Up @@ -582,7 +584,7 @@ func TestReverseReplayDefense(t *testing.T) {
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout, 0)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
Expand Down Expand Up @@ -653,7 +655,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout)
handler := NewTCPHandler(authFunc, testMetrics, testTimeout, 0)

done := make(chan struct{})
go func() {
Expand Down
Loading

0 comments on commit 60aad45

Please sign in to comment.