Skip to content

Commit

Permalink
refactor: pass in logger to service so caller can control logs (#207)
Browse files Browse the repository at this point in the history
* refactor: create re-usable service that can be re-used by Caddy

* Remove need to return errors in opt functions.

* Move the service into `shadowsocks.go`.

* refactor: pass in logger to service so caller can control logs

* Move initialization of handlers to the constructor.

* Pass a `list.List` instead of a `CipherList`.

* Rename `SSServer` to `OutlineServer`.

* refactor: make connection metrics optional

* Make setting the logger a setter function.

* Revert "Pass a `list.List` instead of a `CipherList`."

This reverts commit 1259af8.

* Create noop metrics if nil.

* Revert some more changes.

* Use a noop metrics struct if no metrics provided.

* Add noop implementation for `ShadowsocksConnMetrics`.

* Move logger arg.

* Resolve nil metrics.

* Set logger explicitly to `noopLogger` in service creation.

* Set `noopLogger` in `NewShadowsocksStreamAuthenticator()` if nil.

* Fix logger reference.

* Use a `noopLogger` if `SetLogger()` is called with `nil`.

* Update tests.

* Use concrete `slog.Logger` instead of `Logger` interface now that we don't need a zap adapter for Caddy.

* Move `WithLogger()` down.

* Remove `nil` check.

* Use `math.MaxInt` to make sure no error log records are created.
  • Loading branch information
sbruens authored Sep 23, 2024
1 parent e8ec4d0 commit 5d3d6db
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 57 deletions.
2 changes: 2 additions & 0 deletions cmd/outline-ss-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
service.WithNatTimeout(s.natTimeout),
service.WithMetrics(s.serviceMetrics),
service.WithReplayCache(&s.replayCache),
service.WithLogger(slog.Default()),
)
ln, err := lnSet.ListenStream(addr)
if err != nil {
Expand All @@ -248,6 +249,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
service.WithNatTimeout(s.natTimeout),
service.WithMetrics(s.serviceMetrics),
service.WithReplayCache(&s.replayCache),
service.WithLogger(slog.Default()),
)
if err != nil {
return err
Expand Down
8 changes: 4 additions & 4 deletions internal/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestTCPEcho(t *testing.T) {
replayCache := service.NewReplayCache(5)
const testTimeout = 200 * time.Millisecond
testMetrics := &statusMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, &fakeShadowsocksMetrics{})
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, &fakeShadowsocksMetrics{}, nil)
handler := service.NewStreamHandler(authFunc, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
Expand Down Expand Up @@ -211,7 +211,7 @@ func TestRestrictedAddresses(t *testing.T) {
require.NoError(t, err)
const testTimeout = 200 * time.Millisecond
testMetrics := &statusMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{})
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
handler := service.NewStreamHandler(authFunc, testTimeout)
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -400,7 +400,7 @@ func BenchmarkTCPThroughput(b *testing.B) {
}
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPConnMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{})
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
handler := service.NewStreamHandler(authFunc, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
Expand Down Expand Up @@ -467,7 +467,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) {
replayCache := service.NewReplayCache(service.MaxCapacity)
const testTimeout = 200 * time.Millisecond
testMetrics := &service.NoOpTCPConnMetrics{}
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, &fakeShadowsocksMetrics{})
authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, &fakeShadowsocksMetrics{}, nil)
handler := service.NewStreamHandler(authFunc, testTimeout)
handler.SetTargetDialer(&transport.TCPDialer{})
done := make(chan struct{})
Expand Down
11 changes: 9 additions & 2 deletions service/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@

package service

import logging "github.com/op/go-logging"
import (
"io"
"log/slog"
"math"
)

var logger = logging.MustGetLogger("shadowsocks")
func noopLogger() *slog.Logger {
// TODO: Use built-in no-op log level when available: https://go.dev/issue/62005
return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.Level(math.MaxInt)}))
}
22 changes: 20 additions & 2 deletions service/shadowsocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package service

import (
"context"
"log/slog"
"net"
"time"

Expand Down Expand Up @@ -50,6 +51,7 @@ type Service interface {
type Option func(s *ssService)

type ssService struct {
logger *slog.Logger
metrics ServiceMetrics
ciphers CipherList
natTimeout time.Duration
Expand All @@ -59,28 +61,44 @@ type ssService struct {
ph PacketHandler
}

// NewShadowsocksService creates a new service
// NewShadowsocksService creates a new Shadowsocks service.
func NewShadowsocksService(opts ...Option) (Service, error) {
s := &ssService{}

for _, opt := range opts {
opt(s)
}

// If no NAT timeout is provided via options, use the recommended default.
if s.natTimeout == 0 {
s.natTimeout = defaultNatTimeout
}
// If no logger is provided via options, use a noop logger.
if s.logger == nil {
s.logger = noopLogger()
}

// TODO: Register initial data metrics at zero.
s.sh = NewStreamHandler(
NewShadowsocksStreamAuthenticator(s.ciphers, s.replayCache, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "tcp"}),
NewShadowsocksStreamAuthenticator(s.ciphers, s.replayCache, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "tcp"}, s.logger),
tcpReadTimeout,
)
s.sh.SetLogger(s.logger)

s.ph = NewPacketHandler(s.natTimeout, s.ciphers, s.metrics, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "udp"})
s.ph.SetLogger(s.logger)

return s, nil
}

// WithLogger can be used to provide a custom log target. If not provided,
// the service uses a noop logger (i.e., no logging).
func WithLogger(l *slog.Logger) Option {
return func(s *ssService) {
s.logger = l
}
}

// WithCiphers option function.
func WithCiphers(ciphers CipherList) Option {
return func(s *ssService) {
Expand Down
46 changes: 30 additions & 16 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ func remoteIP(conn net.Conn) netip.Addr {
}

// Wrapper for slog.Debug during TCP access key searches.
func debugTCP(template string, cipherID string, attr slog.Attr) {
func debugTCP(l *slog.Logger, template string, cipherID string, attr slog.Attr) {
// This is an optimization to reduce unnecessary allocations due to an interaction
// between Go's inlining/escape analysis and varargs functions like slog.Debug.
if slog.Default().Enabled(nil, slog.LevelDebug) {
slog.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("TCP: %s", template), slog.String("ID", cipherID), attr)
if l.Enabled(nil, slog.LevelDebug) {
l.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("TCP: %s", template), slog.String("ID", cipherID), attr)
}
}

Expand All @@ -72,7 +72,7 @@ func debugTCP(template string, cipherID string, attr slog.Attr) {
// required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection.
const bytesForKeyFinding = 50

func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList, l *slog.Logger) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
// We snapshot the list because it may be modified while we use it.
ciphers := cipherList.SnapshotForClientIP(clientIP)
firstBytes := make([]byte, bytesForKeyFinding)
Expand All @@ -81,7 +81,7 @@ func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList Ciphe
}

findStartTime := time.Now()
entry, elt := findEntry(firstBytes, ciphers)
entry, elt := findEntry(firstBytes, ciphers, l)
timeToCipher := time.Since(findStartTime)
if entry == nil {
// TODO: Ban and log client IPs with too many failures too quick to protect against DoS.
Expand All @@ -95,18 +95,18 @@ func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList Ciphe
}

// Implements a trial decryption search. This assumes that all ciphers are AEAD.
func findEntry(firstBytes []byte, ciphers []*list.Element) (*CipherEntry, *list.Element) {
func findEntry(firstBytes []byte, ciphers []*list.Element, l *slog.Logger) (*CipherEntry, *list.Element) {
// To hold the decrypted chunk length.
chunkLenBuf := [2]byte{}
for ci, elt := range ciphers {
entry := elt.Value.(*CipherEntry)
cryptoKey := entry.CryptoKey
_, err := shadowsocks.Unpack(chunkLenBuf[:0], firstBytes[:cryptoKey.SaltSize()+2+cryptoKey.TagSize()], cryptoKey)
if err != nil {
debugTCP("Failed to decrypt length.", entry.ID, slog.Any("err", err))
debugTCP(l, "Failed to decrypt length.", entry.ID, slog.Any("err", err))
continue
}
debugTCP("Found cipher.", entry.ID, slog.Int("index", ci))
debugTCP(l, "Found cipher.", entry.ID, slog.Int("index", ci))
return entry, elt
}
return nil, nil
Expand All @@ -116,13 +116,16 @@ type StreamAuthenticateFunc func(clientConn transport.StreamConn) (string, trans

// NewShadowsocksStreamAuthenticator creates a stream authenticator that uses Shadowsocks.
// TODO(fortuna): Offer alternative transports.
func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksConnMetrics) StreamAuthenticateFunc {
func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksConnMetrics, l *slog.Logger) StreamAuthenticateFunc {
if metrics == nil {
metrics = &NoOpShadowsocksConnMetrics{}
}
if l == nil {
l = noopLogger()
}
return func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError) {
// Find the cipher and acess key id.
cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), ciphers)
cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), ciphers, l)
metrics.AddCipherSearch(keyErr == nil, timeToCipher)
if keyErr != nil {
const status = "ERR_CIPHER"
Expand Down Expand Up @@ -154,6 +157,7 @@ func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCa
}

type streamHandler struct {
logger *slog.Logger
listenerId string
readTimeout time.Duration
authenticate StreamAuthenticateFunc
Expand All @@ -163,6 +167,7 @@ type streamHandler struct {
// NewStreamHandler creates a StreamHandler
func NewStreamHandler(authenticate StreamAuthenticateFunc, timeout time.Duration) StreamHandler {
return &streamHandler{
logger: noopLogger(),
readTimeout: timeout,
authenticate: authenticate,
dialer: defaultDialer,
Expand All @@ -181,10 +186,19 @@ func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) tra
// StreamHandler is a handler that handles stream connections.
type StreamHandler interface {
Handle(ctx context.Context, conn transport.StreamConn, connMetrics TCPConnMetrics)
// SetLogger sets the logger used to log messages. Uses a no-op logger if nil.
SetLogger(l *slog.Logger)
// SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses.
SetTargetDialer(dialer transport.StreamDialer)
}

func (s *streamHandler) SetLogger(l *slog.Logger) {
if l == nil {
l = noopLogger()
}
s.logger = l
}

func (s *streamHandler) SetTargetDialer(dialer transport.StreamDialer) {
s.dialer = dialer
}
Expand Down Expand Up @@ -257,11 +271,11 @@ func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamC
status := "OK"
if connError != nil {
status = connError.Status
slog.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
}
connMetrics.AddClosed(status, proxyMetrics, connDuration)
measuredClientConn.Close() // Closing after the metrics are added aids integration testing.
slog.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration))
h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration))
}

func getProxyRequest(clientConn transport.StreamConn) (string, error) {
Expand All @@ -276,14 +290,14 @@ func getProxyRequest(clientConn transport.StreamConn) (string, error) {
return tgtAddr.String(), nil
}

func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError {
func proxyConnection(l *slog.Logger, ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError {
tgtConn, dialErr := dialer.DialStream(ctx, tgtAddr)
if dialErr != nil {
// We don't drain so dial errors and invalid addresses are communicated quickly.
return ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target")
}
defer tgtConn.Close()
slog.LogAttrs(nil, slog.LevelDebug, "Proxy connection.", slog.String("client", clientConn.RemoteAddr().String()), slog.String("target", tgtConn.RemoteAddr().String()))
l.LogAttrs(nil, slog.LevelDebug, "Proxy connection.", slog.String("client", clientConn.RemoteAddr().String()), slog.String("target", tgtConn.RemoteAddr().String()))

fromClientErrCh := make(chan error)
go func() {
Expand Down Expand Up @@ -351,7 +365,7 @@ func (h *streamHandler) handleConnection(ctx context.Context, outerConn transpor
tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
return tgtConn, nil
})
return proxyConnection(ctx, dialer, tgtAddr, innerConn)
return proxyConnection(h.logger, ctx, dialer, tgtAddr, innerConn)
}

// Keep the connection open until we hit the authentication deadline to protect against probing attacks
Expand All @@ -360,7 +374,7 @@ func (h *streamHandler) absorbProbe(clientConn io.ReadCloser, connMetrics TCPCon
// This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe.
_, drainErr := io.Copy(io.Discard, clientConn) // drain socket
drainResult := drainErrToString(drainErr)
slog.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult))
h.logger.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult))
connMetrics.AddProbe(status, drainResult, proxyMetrics.ClientProxy)
}

Expand Down
20 changes: 10 additions & 10 deletions service/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func BenchmarkTCPFindCipherFail(b *testing.B) {
}
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr()
b.StartTimer()
findAccessKey(clientConn, clientIP, cipherList)
findAccessKey(clientConn, clientIP, cipherList, noopLogger())
b.StopTimer()
}
}
Expand Down Expand Up @@ -205,7 +205,7 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
cipher := cipherEntries[cipherNumber].CryptoKey
go shadowsocks.NewWriter(writer, cipher).Write(makeTestPayload(50))
b.StartTimer()
_, _, _, _, err := findAccessKey(&c, clientIP, cipherList)
_, _, _, _, err := findAccessKey(&c, clientIP, cipherList, noopLogger())
b.StopTimer()
if err != nil {
b.Error(err)
Expand Down Expand Up @@ -285,7 +285,7 @@ func TestProbeRandom(t *testing.T) {
cipherList, err := MakeTestCiphers(makeTestSecrets(1))
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{})
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
handler := NewStreamHandler(authFunc, 200*time.Millisecond)
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -365,7 +365,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{})
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
handler := NewStreamHandler(authFunc, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
Expand Down Expand Up @@ -403,7 +403,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{})
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
handler := NewStreamHandler(authFunc, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
Expand Down Expand Up @@ -442,7 +442,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{})
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
handler := NewStreamHandler(authFunc, 200*time.Millisecond)
handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll))
done := make(chan struct{})
Expand Down Expand Up @@ -488,7 +488,7 @@ func TestProbeServerBytesModified(t *testing.T) {
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
cipher := firstCipher(cipherList)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{})
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
handler := NewStreamHandler(authFunc, 200*time.Millisecond)
done := make(chan struct{})
go func() {
Expand Down Expand Up @@ -522,7 +522,7 @@ func TestReplayDefense(t *testing.T) {
replayCache := NewReplayCache(5)
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics, nil)
handler := NewStreamHandler(authFunc, testTimeout)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
Expand Down Expand Up @@ -604,7 +604,7 @@ func TestReverseReplayDefense(t *testing.T) {
replayCache := NewReplayCache(5)
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics, nil)
handler := NewStreamHandler(authFunc, testTimeout)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
Expand Down Expand Up @@ -678,7 +678,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) {
cipherList, err := MakeTestCiphers(makeTestSecrets(5))
require.NoError(t, err, "MakeTestCiphers failed: %v", err)
testMetrics := &probeTestMetrics{}
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{})
authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil)
handler := NewStreamHandler(authFunc, testTimeout)

done := make(chan struct{})
Expand Down
Loading

0 comments on commit 5d3d6db

Please sign in to comment.