Skip to content

Commit

Permalink
cleanup: clean up TCP calls and use netip (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored May 13, 2024
1 parent 4c35a51 commit c4d9214
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 50 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ jobs:
run: go build -v ./...

- name: Test
run: go test -v -race -benchmem -bench=. ./... -benchtime=100ms
run: go test -race -benchmem -bench=. ./... -benchtime=100ms
6 changes: 6 additions & 0 deletions cmd/outline-ss-server/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
"github.com/op/go-logging"
"github.com/prometheus/client_golang/prometheus"
promtest "github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"
Expand All @@ -45,6 +46,10 @@ func setNow(t time.Time) {
}
}

func init() {
logging.SetLevel(logging.INFO, "")
}

func TestMethodsDontPanic(t *testing.T) {
ssMetrics := newPrometheusOutlineMetrics(nil, prometheus.NewPedanticRegistry())
proxyMetrics := metrics.ProxyMetrics{
Expand Down Expand Up @@ -149,6 +154,7 @@ func BenchmarkCloseTCP(b *testing.B) {
duration := time.Minute
b.ResetTimer()
for i := 0; i < b.N; i++ {
ssMetrics.AddAuthenticatedTCPConnection(addr, accessKey)
ssMetrics.AddClosedTCPConnection(ipinfo, addr, accessKey, status, data, duration)
ssMetrics.AddTCPCipherSearch(true, timeToCipher)
}
Expand Down
5 changes: 3 additions & 2 deletions internal/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -107,7 +108,7 @@ func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) {
t.Logf("Failed to read from UDP conn: %v", err)
return
}
conn.WriteTo(buf[:n], clientAddr)
_, err = conn.WriteTo(buf[:n], clientAddr)
if err != nil {
t.Fatalf("Failed to write: %v", err)
}
Expand Down Expand Up @@ -335,7 +336,7 @@ func TestUDPEcho(t *testing.T) {
proxyConn.Close()
<-done
// Verify that the expected metrics were reported.
snapshot := cipherList.SnapshotForClientIP(nil)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
keyID := snapshot[0].Value.(*service.CipherEntry).ID

if testMetrics.natAdded != 1 {
Expand Down
16 changes: 8 additions & 8 deletions service/cipher_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package service

import (
"container/list"
"net"
"net/netip"
"sync"

"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
Expand All @@ -31,7 +31,7 @@ type CipherEntry struct {
ID string
CryptoKey *shadowsocks.EncryptionKey
SaltGenerator ServerSaltGenerator
lastClientIP net.IP
lastClientIP netip.Addr
}

// MakeCipherEntry constructs a CipherEntry.
Expand All @@ -56,8 +56,8 @@ func MakeCipherEntry(id string, cryptoKey *shadowsocks.EncryptionKey, secret str
// snapshotting and moving to front.
type CipherList interface {
// Returns a snapshot of the cipher list optimized for this client IP
SnapshotForClientIP(clientIP net.IP) []*list.Element
MarkUsedByClientIP(e *list.Element, clientIP net.IP)
SnapshotForClientIP(clientIP netip.Addr) []*list.Element
MarkUsedByClientIP(e *list.Element, clientIP netip.Addr)
// Update replaces the current contents of the CipherList with `contents`,
// which is a List of *CipherEntry. Update takes ownership of `contents`,
// which must not be read or written after this call.
Expand All @@ -75,12 +75,12 @@ func NewCipherList() CipherList {
return &cipherList{list: list.New()}
}

func matchesIP(e *list.Element, clientIP net.IP) bool {
func matchesIP(e *list.Element, clientIP netip.Addr) bool {
c := e.Value.(*CipherEntry)
return clientIP != nil && clientIP.Equal(c.lastClientIP)
return clientIP != netip.Addr{} && clientIP == c.lastClientIP
}

func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element {
func (cl *cipherList) SnapshotForClientIP(clientIP netip.Addr) []*list.Element {
cl.mu.RLock()
defer cl.mu.RUnlock()
cipherArray := make([]*list.Element, cl.list.Len())
Expand All @@ -102,7 +102,7 @@ func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element {
return cipherArray
}

func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP net.IP) {
func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP netip.Addr) {
cl.mu.Lock()
defer cl.mu.Unlock()
cl.list.MoveToFront(e)
Expand Down
12 changes: 6 additions & 6 deletions service/cipher_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ package service

import (
"math/rand"
"net"
"net/netip"
"testing"
)

func BenchmarkLocking(b *testing.B) {
var ip net.IP
var ip netip.Addr

ciphers, _ := MakeTestCiphers([]string{"secret"})
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
entries := ciphers.SnapshotForClientIP(nil)
entries := ciphers.SnapshotForClientIP(netip.Addr{})
ciphers.MarkUsedByClientIP(entries[0], ip)
}
})
Expand All @@ -43,20 +43,20 @@ func BenchmarkSnapshot(b *testing.B) {

// Shuffling simulates the behavior of a real server, where successive
// ciphers are not expected to be nearby in memory.
entries := ciphers.SnapshotForClientIP(nil)
entries := ciphers.SnapshotForClientIP(netip.Addr{})
rand.Shuffle(N, func(i, j int) {
entries[i], entries[j] = entries[j], entries[i]
})
for _, entry := range entries {
// Reorder the list to match the shuffle
// (actually in reverse, but it doesn't matter).
ciphers.MarkUsedByClientIP(entry, nil)
ciphers.MarkUsedByClientIP(entry, netip.Addr{})
}

b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
ciphers.SnapshotForClientIP(nil)
ciphers.SnapshotForClientIP(netip.Addr{})
}
})
}
25 changes: 13 additions & 12 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"sync"
"syscall"
"time"
Expand All @@ -46,19 +47,19 @@ type TCPMetrics interface {
AddTCPProbe(status, drainResult string, port int, clientProxyBytes int64)
}

func remoteIP(conn net.Conn) net.IP {
func remoteIP(conn net.Conn) netip.Addr {
addr := conn.RemoteAddr()
if addr == nil {
return nil
return netip.Addr{}
}
if tcpaddr, ok := addr.(*net.TCPAddr); ok {
return tcpaddr.IP
return tcpaddr.AddrPort().Addr()
}
ipstr, _, err := net.SplitHostPort(addr.String())
addrPort, err := netip.ParseAddrPort(addr.String())
if err == nil {
return net.ParseIP(ipstr)
return addrPort.Addr()
}
return nil
return netip.Addr{}
}

// Wrapper for logger.Debugf during TCP access key searches.
Expand All @@ -76,7 +77,7 @@ func debugTCP(cipherID, template string, val interface{}) {
// required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection.
const bytesForKeyFinding = 50

func findAccessKey(clientReader io.Reader, clientIP net.IP, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
// We snapshot the list because it may be modified while we use it.
ciphers := cipherList.SnapshotForClientIP(clientIP)
firstBytes := make([]byte, bytesForKeyFinding)
Expand Down Expand Up @@ -264,7 +265,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn
measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy)
connStart := time.Now()

id, connError := h.handleConnection(ctx, h.port, clientInfo, measuredClientConn, &proxyMetrics)
id, connError := h.handleConnection(ctx, measuredClientConn, &proxyMetrics)

connDuration := time.Since(connStart)
status := "OK"
Expand Down Expand Up @@ -327,7 +328,7 @@ func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr
return nil
}

func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientInfo ipinfo.IPInfo, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) {
// Set a deadline to receive the address to the target.
readDeadline := time.Now().Add(h.readTimeout)
if deadline, ok := ctx.Deadline(); ok {
Expand All @@ -341,7 +342,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli
id, innerConn, authErr := h.authenticate(outerConn)
if authErr != nil {
// Drain to protect against probing attacks.
h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics)
h.absorbProbe(outerConn, authErr.Status, proxyMetrics)
return id, authErr
}
h.m.AddAuthenticatedTCPConnection(outerConn.RemoteAddr(), id)
Expand Down Expand Up @@ -369,12 +370,12 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli

// Keep the connection open until we hit the authentication deadline to protect against probing attacks
// `proxyMetrics` is a pointer because its value is being mutated by `clientConn`.
func (h *tcpHandler) absorbProbe(listenerPort int, clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) {
func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, status string, proxyMetrics *metrics.ProxyMetrics) {
// This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe.
_, drainErr := io.Copy(io.Discard, clientConn) // drain socket
drainResult := drainErrToString(drainErr)
logger.Debugf("Drain error: %v, drain result: %v", drainErr, drainResult)
h.m.AddTCPProbe(status, drainResult, listenerPort, proxyMetrics.ClientProxy)
h.m.AddTCPProbe(status, drainResult, h.port, proxyMetrics.ClientProxy)
}

func drainErrToString(drainErr error) string {
Expand Down
20 changes: 9 additions & 11 deletions service/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"io"
"math/rand"
"net"
"net/netip"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -99,7 +100,7 @@ func BenchmarkTCPFindCipherFail(b *testing.B) {
if err != nil {
b.Fatalf("AcceptTCP failed: %v", err)
}
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).IP
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr()
b.StartTimer()
findAccessKey(clientConn, clientIP, cipherList)
b.StopTimer()
Expand Down Expand Up @@ -191,16 +192,16 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
b.Fatal(err)
}
cipherEntries := [numCiphers]*CipherEntry{}
snapshot := cipherList.SnapshotForClientIP(nil)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
for cipherNumber, element := range snapshot {
cipherEntries[cipherNumber] = element.Value.(*CipherEntry)
}
for n := 0; n < b.N; n++ {
cipherNumber := byte(n % numCiphers)
reader, writer := io.Pipe()
clientIP := net.IPv4(192, 0, 2, cipherNumber)
addr := &net.TCPAddr{IP: clientIP, Port: 54321}
c := conn{clientAddr: addr, reader: reader, writer: writer}
clientIP := netip.AddrFrom4([4]byte{192, 0, 2, cipherNumber})
addr := netip.AddrPortFrom(clientIP, 54321)
c := conn{clientAddr: net.TCPAddrFromAddrPort(addr), reader: reader, writer: writer}
cipher := cipherEntries[cipherNumber].CryptoKey
go shadowsocks.NewWriter(writer, cipher).Write(makeTestPayload(50))
b.StartTimer()
Expand Down Expand Up @@ -345,7 +346,7 @@ func makeClientBytesCoalesced(t *testing.T, cryptoKey *shadowsocks.EncryptionKey
}

func firstCipher(cipherList CipherList) *shadowsocks.EncryptionKey {
snapshot := cipherList.SnapshotForClientIP(nil)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
return cipherEntry.CryptoKey
}
Expand All @@ -368,7 +369,6 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) {
discardListener, discardWait := startDiscardServer(t)
initialBytes := makeClientBytesBasic(t, cipher, discardListener.Addr().String())
for numBytesToSend := 0; numBytesToSend < len(initialBytes); numBytesToSend++ {
t.Logf("Sending %v bytes", numBytesToSend)
bytesToSend := initialBytes[:numBytesToSend]
err := probe(listener.Addr().(*net.TCPAddr), bytesToSend)
require.NoError(t, err, "Failed for %v bytes sent: %v", numBytesToSend, err)
Expand Down Expand Up @@ -405,7 +405,6 @@ func TestProbeClientBytesBasicModified(t *testing.T) {
initialBytes := makeClientBytesBasic(t, cipher, discardListener.Addr().String())
bytesToSend := make([]byte, len(initialBytes))
for byteToModify := 0; byteToModify < len(initialBytes); byteToModify++ {
t.Logf("Modifying byte %v", byteToModify)
copy(bytesToSend, initialBytes)
bytesToSend[byteToModify] = 255 - bytesToSend[byteToModify]
err := probe(listener.Addr().(*net.TCPAddr), bytesToSend)
Expand Down Expand Up @@ -442,7 +441,6 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) {
initialBytes := makeClientBytesCoalesced(t, cipher, discardListener.Addr().String())
bytesToSend := make([]byte, len(initialBytes))
for byteToModify := 0; byteToModify < len(initialBytes); byteToModify++ {
t.Logf("Modifying byte %v", byteToModify)
copy(bytesToSend, initialBytes)
bytesToSend[byteToModify] = 255 - bytesToSend[byteToModify]
err := probe(listener.Addr().(*net.TCPAddr), bytesToSend)
Expand Down Expand Up @@ -506,7 +504,7 @@ func TestReplayDefense(t *testing.T) {
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
snapshot := cipherList.SnapshotForClientIP(nil)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
reader, writer := io.Pipe()
Expand Down Expand Up @@ -585,7 +583,7 @@ func TestReverseReplayDefense(t *testing.T) {
const testTimeout = 200 * time.Millisecond
authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics)
handler := NewTCPHandler(listener.Addr().(*net.TCPAddr).Port, authFunc, testMetrics, testTimeout)
snapshot := cipherList.SnapshotForClientIP(nil)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cipherEntry := snapshot[0].Value.(*CipherEntry)
cipher := cipherEntry.CryptoKey
reader, writer := io.Pipe()
Expand Down
5 changes: 3 additions & 2 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"runtime/debug"
"sync"
"time"
Expand Down Expand Up @@ -64,7 +65,7 @@ func debugUDPAddr(addr net.Addr, template string, val interface{}) {

// Decrypts src into dst. It tries each cipher until it finds one that authenticates
// correctly. dst and src must not overlap.
func findAccessKeyUDP(clientIP net.IP, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) {
func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) {
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
// We snapshot the list because it may be modified while we use it.
snapshot := cipherList.SnapshotForClientIP(clientIP)
Expand Down Expand Up @@ -156,7 +157,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) {
}
debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo)

ip := clientAddr.(*net.UDPAddr).IP
ip := clientAddr.(*net.UDPAddr).AddrPort().Addr()
var textData []byte
var cryptoKey *shadowsocks.EncryptionKey
unpackStart := time.Now()
Expand Down
Loading

0 comments on commit c4d9214

Please sign in to comment.