From c4d92145f91817173548a7aedca764f5655bc98e Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Mon, 13 May 2024 19:47:06 -0400 Subject: [PATCH] cleanup: clean up TCP calls and use netip (#179) --- .github/workflows/go.yml | 2 +- cmd/outline-ss-server/metrics_test.go | 6 +++++ internal/integration_test/integration_test.go | 5 ++-- service/cipher_list.go | 16 ++++++------ service/cipher_list_test.go | 12 ++++----- service/tcp.go | 25 ++++++++++--------- service/tcp_test.go | 20 +++++++-------- service/udp.go | 5 ++-- service/udp_test.go | 17 +++++++------ 9 files changed, 58 insertions(+), 50 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index c6d56208..67d0536d 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -33,4 +33,4 @@ jobs: run: go build -v ./... - name: Test - run: go test -v -race -benchmem -bench=. ./... -benchtime=100ms + run: go test -race -benchmem -bench=. ./... -benchtime=100ms diff --git a/cmd/outline-ss-server/metrics_test.go b/cmd/outline-ss-server/metrics_test.go index 8087d6e5..353520e4 100644 --- a/cmd/outline-ss-server/metrics_test.go +++ b/cmd/outline-ss-server/metrics_test.go @@ -22,6 +22,7 @@ import ( "github.com/Jigsaw-Code/outline-ss-server/ipinfo" "github.com/Jigsaw-Code/outline-ss-server/service/metrics" + "github.com/op/go-logging" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" @@ -45,6 +46,10 @@ func setNow(t time.Time) { } } +func init() { + logging.SetLevel(logging.INFO, "") +} + func TestMethodsDontPanic(t *testing.T) { ssMetrics := newPrometheusOutlineMetrics(nil, prometheus.NewPedanticRegistry()) proxyMetrics := metrics.ProxyMetrics{ @@ -149,6 +154,7 @@ func BenchmarkCloseTCP(b *testing.B) { duration := time.Minute b.ResetTimer() for i := 0; i < b.N; i++ { + ssMetrics.AddAuthenticatedTCPConnection(addr, accessKey) ssMetrics.AddClosedTCPConnection(ipinfo, addr, accessKey, status, data, duration) ssMetrics.AddTCPCipherSearch(true, timeToCipher) } diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index ec9338d6..4ca2f120 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "testing" "time" @@ -107,7 +108,7 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) { t.Logf("Failed to read from UDP conn: %v", err) return } - conn.WriteTo(buf[:n], clientAddr) + _, err = conn.WriteTo(buf[:n], clientAddr) if err != nil { t.Fatalf("Failed to write: %v", err) } @@ -335,7 +336,7 @@ func TestUDPEcho(t *testing.T) { proxyConn.Close() <-done // Verify that the expected metrics were reported. - snapshot := cipherList.SnapshotForClientIP(nil) + snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) keyID := snapshot[0].Value.(*service.CipherEntry).ID if testMetrics.natAdded != 1 { diff --git a/service/cipher_list.go b/service/cipher_list.go index cadcc40b..3b6f1957 100644 --- a/service/cipher_list.go +++ b/service/cipher_list.go @@ -16,7 +16,7 @@ package service import ( "container/list" - "net" + "net/netip" "sync" "github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks" @@ -31,7 +31,7 @@ type CipherEntry struct { ID string CryptoKey *shadowsocks.EncryptionKey SaltGenerator ServerSaltGenerator - lastClientIP net.IP + lastClientIP netip.Addr } // MakeCipherEntry constructs a CipherEntry. @@ -56,8 +56,8 @@ func MakeCipherEntry(id string, cryptoKey *shadowsocks.EncryptionKey, secret str // snapshotting and moving to front. type CipherList interface { // Returns a snapshot of the cipher list optimized for this client IP - SnapshotForClientIP(clientIP net.IP) []*list.Element - MarkUsedByClientIP(e *list.Element, clientIP net.IP) + SnapshotForClientIP(clientIP netip.Addr) []*list.Element + MarkUsedByClientIP(e *list.Element, clientIP netip.Addr) // Update replaces the current contents of the CipherList with `contents`, // which is a List of *CipherEntry. Update takes ownership of `contents`, // which must not be read or written after this call. @@ -75,12 +75,12 @@ func NewCipherList() CipherList { return &cipherList{list: list.New()} } -func matchesIP(e *list.Element, clientIP net.IP) bool { +func matchesIP(e *list.Element, clientIP netip.Addr) bool { c := e.Value.(*CipherEntry) - return clientIP != nil && clientIP.Equal(c.lastClientIP) + return clientIP != netip.Addr{} && clientIP == c.lastClientIP } -func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element { +func (cl *cipherList) SnapshotForClientIP(clientIP netip.Addr) []*list.Element { cl.mu.RLock() defer cl.mu.RUnlock() cipherArray := make([]*list.Element, cl.list.Len()) @@ -102,7 +102,7 @@ func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element { return cipherArray } -func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP net.IP) { +func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP netip.Addr) { cl.mu.Lock() defer cl.mu.Unlock() cl.list.MoveToFront(e) diff --git a/service/cipher_list_test.go b/service/cipher_list_test.go index 94dc23ce..00b2a9be 100644 --- a/service/cipher_list_test.go +++ b/service/cipher_list_test.go @@ -16,18 +16,18 @@ package service import ( "math/rand" - "net" + "net/netip" "testing" ) func BenchmarkLocking(b *testing.B) { - var ip net.IP + var ip netip.Addr ciphers, _ := MakeTestCiphers([]string{"secret"}) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - entries := ciphers.SnapshotForClientIP(nil) + entries := ciphers.SnapshotForClientIP(netip.Addr{}) ciphers.MarkUsedByClientIP(entries[0], ip) } }) @@ -43,20 +43,20 @@ func BenchmarkSnapshot(b *testing.B) { // Shuffling simulates the behavior of a real server, where successive // ciphers are not expected to be nearby in memory. - entries := ciphers.SnapshotForClientIP(nil) + entries := ciphers.SnapshotForClientIP(netip.Addr{}) rand.Shuffle(N, func(i, j int) { entries[i], entries[j] = entries[j], entries[i] }) for _, entry := range entries { // Reorder the list to match the shuffle // (actually in reverse, but it doesn't matter). - ciphers.MarkUsedByClientIP(entry, nil) + ciphers.MarkUsedByClientIP(entry, netip.Addr{}) } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - ciphers.SnapshotForClientIP(nil) + ciphers.SnapshotForClientIP(netip.Addr{}) } }) } diff --git a/service/tcp.go b/service/tcp.go index 9761d8f3..85ab9990 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -22,6 +22,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "syscall" "time" @@ -46,19 +47,19 @@ type TCPMetrics interface { AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64) } -func remoteIP(conn net.Conn) net.IP { +func remoteIP(conn net.Conn) netip.Addr { addr := conn.RemoteAddr() if addr == nil { - return nil + return netip.Addr{} } if tcpaddr, ok := addr.(*net.TCPAddr); ok { - return tcpaddr.IP + return tcpaddr.AddrPort().Addr() } - ipstr, _, err := net.SplitHostPort(addr.String()) + addrPort, err := netip.ParseAddrPort(addr.String()) if err == nil { - return net.ParseIP(ipstr) + return addrPort.Addr() } - return nil + return netip.Addr{} } // Wrapper for logger.Debugf during TCP access key searches. @@ -76,7 +77,7 @@ func debugTCP(cipherID, template string, val interface{}) { // required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection. const bytesForKeyFinding = 50 -func findAccessKey(clientReader io.Reader, clientIP net.IP, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) { +func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) { // We snapshot the list because it may be modified while we use it. ciphers := cipherList.SnapshotForClientIP(clientIP) firstBytes := make([]byte, bytesForKeyFinding) @@ -264,7 +265,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy) connStart := time.Now() - id, connError := h.handleConnection(ctx, h.port, clientInfo, measuredClientConn, &proxyMetrics) + id, connError := h.handleConnection(ctx, measuredClientConn, &proxyMetrics) connDuration := time.Since(connStart) status := "OK" @@ -327,7 +328,7 @@ func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr return nil } -func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientInfo ipinfo.IPInfo, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) { +func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) { // Set a deadline to receive the address to the target. readDeadline := time.Now().Add(h.readTimeout) if deadline, ok := ctx.Deadline(); ok { @@ -341,7 +342,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli id, innerConn, authErr := h.authenticate(outerConn) if authErr != nil { // Drain to protect against probing attacks. - h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics) + h.absorbProbe(outerConn, authErr.Status, proxyMetrics) return id, authErr } h.m.AddAuthenticatedTCPConnection(outerConn.RemoteAddr(), id) @@ -369,12 +370,12 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli // Keep the connection open until we hit the authentication deadline to protect against probing attacks // `proxyMetrics` is a pointer because its value is being mutated by `clientConn`. -func (h *tcpHandler) absorbProbe(listenerPort int, clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) { +func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) { // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe. _, drainErr := io.Copy(io.Discard, clientConn) // drain socket drainResult := drainErrToString(drainErr) logger.Debugf("Drain error: %v, drain result: %v", drainErr, drainResult) - h.m.AddTCPProbe(status, drainResult, listenerPort, proxyMetrics.ClientProxy) + h.m.AddTCPProbe(status, drainResult, h.port, proxyMetrics.ClientProxy) } func drainErrToString(drainErr error) string { diff --git a/service/tcp_test.go b/service/tcp_test.go index e3742806..1a70ed67 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -21,6 +21,7 @@ import ( "io" "math/rand" "net" + "net/netip" "sync" "testing" "time" @@ -99,7 +100,7 @@ func BenchmarkTCPFindCipherFail(b *testing.B) { if err != nil { b.Fatalf("AcceptTCP failed: %v", err) } - clientIP := clientConn.RemoteAddr().(*net.TCPAddr).IP + clientIP := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr() b.StartTimer() findAccessKey(clientConn, clientIP, cipherList) b.StopTimer() @@ -191,16 +192,16 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) { b.Fatal(err) } cipherEntries := [numCiphers]*CipherEntry{} - snapshot := cipherList.SnapshotForClientIP(nil) + snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) for cipherNumber, element := range snapshot { cipherEntries[cipherNumber] = element.Value.(*CipherEntry) } for n := 0; n < b.N; n++ { cipherNumber := byte(n % numCiphers) reader, writer := io.Pipe() - clientIP := net.IPv4(192, 0, 2, cipherNumber) - addr := &net.TCPAddr{IP: clientIP, Port: 54321} - c := conn{clientAddr: addr, reader: reader, writer: writer} + clientIP := netip.AddrFrom4([4]byte{192, 0, 2, cipherNumber}) + addr := netip.AddrPortFrom(clientIP, 54321) + c := conn{clientAddr: net.TCPAddrFromAddrPort(addr), reader: reader, writer: writer} cipher := cipherEntries[cipherNumber].CryptoKey go shadowsocks.NewWriter(writer, cipher).Write(makeTestPayload(50)) b.StartTimer() @@ -345,7 +346,7 @@ func makeClientBytesCoalesced(t *testing.T, cryptoKey *shadowsocks.EncryptionKey } func firstCipher(cipherList CipherList) *shadowsocks.EncryptionKey { - snapshot := cipherList.SnapshotForClientIP(nil) + snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) return cipherEntry.CryptoKey } @@ -368,7 +369,6 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { discardListener, discardWait := startDiscardServer(t) initialBytes := makeClientBytesBasic(t, cipher, discardListener.Addr().String()) for numBytesToSend := 0; numBytesToSend < len(initialBytes); numBytesToSend++ { - t.Logf("Sending %v bytes", numBytesToSend) bytesToSend := initialBytes[:numBytesToSend] err := probe(listener.Addr().(*net.TCPAddr), bytesToSend) require.NoError(t, err, "Failed for %v bytes sent: %v", numBytesToSend, err) @@ -405,7 +405,6 @@ func TestProbeClientBytesBasicModified(t *testing.T) { initialBytes := makeClientBytesBasic(t, cipher, discardListener.Addr().String()) bytesToSend := make([]byte, len(initialBytes)) for byteToModify := 0; byteToModify < len(initialBytes); byteToModify++ { - t.Logf("Modifying byte %v", byteToModify) copy(bytesToSend, initialBytes) bytesToSend[byteToModify] = 255 - bytesToSend[byteToModify] err := probe(listener.Addr().(*net.TCPAddr), bytesToSend) @@ -442,7 +441,6 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { initialBytes := makeClientBytesCoalesced(t, cipher, discardListener.Addr().String()) bytesToSend := make([]byte, len(initialBytes)) for byteToModify := 0; byteToModify < len(initialBytes); byteToModify++ { - t.Logf("Modifying byte %v", byteToModify) copy(bytesToSend, initialBytes) bytesToSend[byteToModify] = 255 - bytesToSend[byteToModify] err := probe(listener.Addr().(*net.TCPAddr), bytesToSend) @@ -506,7 +504,7 @@ func TestReplayDefense(t *testing.T) { const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) - snapshot := cipherList.SnapshotForClientIP(nil) + snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey reader, writer := io.Pipe() @@ -585,7 +583,7 @@ func TestReverseReplayDefense(t *testing.T) { const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout) - snapshot := cipherList.SnapshotForClientIP(nil) + snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey reader, writer := io.Pipe() diff --git a/service/udp.go b/service/udp.go index d57b8e8a..4830e302 100644 --- a/service/udp.go +++ b/service/udp.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "net" + "net/netip" "runtime/debug" "sync" "time" @@ -64,7 +65,7 @@ func debugUDPAddr(addr net.Addr, template string, val interface{}) { // Decrypts src into dst. It tries each cipher until it finds one that authenticates // correctly. dst and src must not overlap. -func findAccessKeyUDP(clientIP net.IP, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) { +func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) { // Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD. // We snapshot the list because it may be modified while we use it. snapshot := cipherList.SnapshotForClientIP(clientIP) @@ -156,7 +157,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { } debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) - ip := clientAddr.(*net.UDPAddr).IP + ip := clientAddr.(*net.UDPAddr).AddrPort().Addr() var textData []byte var cryptoKey *shadowsocks.EncryptionKey unpackStart := time.Now() diff --git a/service/udp_test.go b/service/udp_test.go index 30aae003..f94238c5 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -18,6 +18,7 @@ import ( "bytes" "errors" "net" + "net/netip" "sync" "testing" "time" @@ -124,7 +125,7 @@ func (m *natTestMetrics) AddUDPCipherSearch(accessKeyFound bool, timeToCipher ti // generates when localhost access is attempted func sendToDiscard(payloads [][]byte, validator onet.TargetIPValidator) *natTestMetrics { ciphers, _ := MakeTestCiphers([]string{"asdf"}) - cipher := ciphers.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).CryptoKey + cipher := ciphers.SnapshotForClientIP(netip.Addr{})[0].Value.(*CipherEntry).CryptoKey clientConn := makePacketConn() metrics := &natTestMetrics{} handler := NewPacketHandler(timeout, ciphers, metrics) @@ -403,7 +404,7 @@ func BenchmarkUDPUnpackFail(b *testing.B) { } testPayload := makeTestPayload(50) textBuf := make([]byte, serverUDPBufferSize) - testIP := net.ParseIP("192.0.2.1") + testIP := netip.MustParseAddr("192.0.2.1") b.ResetTimer() for n := 0; n < b.N; n++ { findAccessKeyUDP(testIP, textBuf, testPayload, cipherList) @@ -420,8 +421,8 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) { } testBuf := make([]byte, serverUDPBufferSize) packets := [numCiphers][]byte{} - ips := [numCiphers]net.IP{} - snapshot := cipherList.SnapshotForClientIP(nil) + ips := [numCiphers]netip.Addr{} + snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) for i, element := range snapshot { packets[i] = make([]byte, 0, serverUDPBufferSize) plaintext := makeTestPayload(50) @@ -429,7 +430,7 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) { if err != nil { b.Error(err) } - ips[i] = net.IPv4(192, 0, 2, byte(i)) + ips[i] = netip.AddrFrom4([4]byte{192, 0, 2, byte(i)}) } b.ResetTimer() for n := 0; n < b.N; n++ { @@ -452,15 +453,15 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) { } testBuf := make([]byte, serverUDPBufferSize) plaintext := makeTestPayload(50) - snapshot := cipherList.SnapshotForClientIP(nil) + snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cryptoKey := snapshot[0].Value.(*CipherEntry).CryptoKey packet, err := shadowsocks.Pack(make([]byte, serverUDPBufferSize), plaintext, cryptoKey) require.Nil(b, err) const numIPs = 100 // Must be <256 - ips := [numIPs]net.IP{} + ips := [numIPs]netip.Addr{} for i := 0; i < numIPs; i++ { - ips[i] = net.IPv4(192, 0, 2, byte(i)) + ips[i] = netip.AddrFrom4([4]byte{192, 0, 2, byte(i)}) } b.ResetTimer() for n := 0; n < b.N; n++ {