From 5dd79fd2e69ca9ea03f8577c072f5bdbdfc16908 Mon Sep 17 00:00:00 2001 From: Mark Holt Date: Mon, 15 Jul 2024 09:36:00 +0100 Subject: [PATCH] fix races --- client.go | 7 +++++- peerconn.go | 26 +++++++++++++++------- peerconn_test.go | 4 ++-- pex.go | 8 +++---- pex_test.go | 56 ++++++++++++++++++++++++------------------------ torrent.go | 8 +++---- 6 files changed, 62 insertions(+), 47 deletions(-) diff --git a/client.go b/client.go index 6867cbafdd..5996593414 100644 --- a/client.go +++ b/client.go @@ -1249,13 +1249,18 @@ func (cl *Client) gotMetadataExtensionMsg(payload []byte, t *Torrent, c *PeerCon if !c.requestedMetadataPiece(piece) { return fmt.Errorf("got unexpected piece %d", piece) } + + c.mu.Lock() c.metadataRequests[piece] = false + c.lastUsefulChunkReceived = time.Now() + c.mu.Unlock() + // begin := len(payload) - d.PieceSize() if begin < 0 || begin >= len(payload) { return fmt.Errorf("data has bad offset in payload: %d", begin) } t.saveMetadataPiece(piece, payload[begin:]) - c.lastUsefulChunkReceived = time.Now() + err = t.maybeCompleteMetadata() if err != nil { // Log this at the Torrent-level, as we don't partition metadata by Peer yet, so we diff --git a/peerconn.go b/peerconn.go index 9d0a73fbfa..c712e2af2f 100644 --- a/peerconn.go +++ b/peerconn.go @@ -1099,7 +1099,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err c.requestPendingMetadata(false) if !t.cl.config.DisablePEX { c.mu.Lock() - t.pex.Add(c) // we learnt enough now + t.pex.Add(c, false) // we learnt enough now // This checks the extension is supported internally. c.pex.Init(c) c.mu.Unlock() @@ -1295,7 +1295,12 @@ func (c *Peer) setTorrent(t *Torrent, lockTorrent bool) { t.reconcileHandshakeStats(c) } -func (c *PeerConn) pexPeerFlags() pp.PexPeerFlags { +func (c *PeerConn) pexPeerFlags(lock bool) pp.PexPeerFlags { + if lock { + c.mu.RLock() + c.mu.RUnlock() + } + f := pp.PexPeerFlags(0) if c.PeerPrefersEncryption { f |= pp.PexPrefersEncryption @@ -1311,7 +1316,12 @@ func (c *PeerConn) pexPeerFlags() pp.PexPeerFlags { // This returns the address to use if we want to dial the peer again. It incorporates the peer's // advertised listen port. -func (c *PeerConn) dialAddr() PeerRemoteAddr { +func (c *PeerConn) dialAddr(lock bool) PeerRemoteAddr { + if lock { + c.mu.RLock() + c.mu.RUnlock() + } + if c.outgoing || c.PeerListenPort == 0 { return c.RemoteAddr } @@ -1328,9 +1338,9 @@ func (c *PeerConn) dialAddr() PeerRemoteAddr { return netip.AddrPortFrom(addrPort.Addr(), uint16(c.PeerListenPort)) } -func (c *PeerConn) pexEvent(t pexEventType) (_ pexEvent, err error) { - f := c.pexPeerFlags() - dialAddr := c.dialAddr() +func (c *PeerConn) pexEvent(t pexEventType, lock bool) (_ pexEvent, err error) { + f := c.pexPeerFlags(lock) + dialAddr := c.dialAddr(lock) addr, err := addrPortFromPeerRemoteAddr(dialAddr) if err != nil || !addr.IsValid() { err = fmt.Errorf("parsing dial addr %q: %w", dialAddr, err) @@ -1353,8 +1363,8 @@ func (pc *PeerConn) remoteIsTransmission() bool { return bytes.HasPrefix(pc.PeerID[:], []byte("-TR")) && pc.PeerID[7] == '-' } -func (pc *PeerConn) remoteDialAddrPort() (netip.AddrPort, error) { - dialAddr := pc.dialAddr() +func (pc *PeerConn) remoteDialAddrPort(lock bool) (netip.AddrPort, error) { + dialAddr := pc.dialAddr(lock) return addrPortFromPeerRemoteAddr(dialAddr) } diff --git a/peerconn_test.go b/peerconn_test.go index 71611fd7f6..dce361d117 100644 --- a/peerconn_test.go +++ b/peerconn_test.go @@ -212,7 +212,7 @@ func TestConnPexPeerFlags(t *testing.T) { {&PeerConn{Peer: Peer{RemoteAddr: tcpAddr, Network: tcpAddr.Network()}}, 0}, } for i, tc := range testcases { - f := tc.conn.pexPeerFlags() + f := tc.conn.pexPeerFlags(true) require.EqualValues(t, tc.f, f, i) } } @@ -262,7 +262,7 @@ func TestConnPexEvent(t *testing.T) { } for i, tc := range testcases { c.Run(fmt.Sprintf("%v", i), func(c *qt.C) { - e, err := tc.c.pexEvent(tc.t) + e, err := tc.c.pexEvent(tc.t, true) c.Assert(err, qt.IsNil) c.Check(e, qt.Equals, tc.e) }) diff --git a/pex.go b/pex.go index a0a5f49f5b..a37885845f 100644 --- a/pex.go +++ b/pex.go @@ -173,8 +173,8 @@ func (s *pexState) append(e *pexEvent) { s.msg0.append(*e) } -func (s *pexState) Add(c *PeerConn) { - e, err := c.pexEvent(pexAdd) +func (s *pexState) Add(c *PeerConn, lockPeer bool) { + e, err := c.pexEvent(pexAdd, lockPeer) if err != nil { return } @@ -192,12 +192,12 @@ func (s *pexState) Add(c *PeerConn) { s.append(&e) } -func (s *pexState) Drop(c *PeerConn) { +func (s *pexState) Drop(c *PeerConn, lockPeer bool) { if !c.pex.Listed { // skip connections which were not previously Added return } - e, err := c.pexEvent(pexDrop) + e, err := c.pexEvent(pexDrop, lockPeer) if err != nil { return } diff --git a/pex_test.go b/pex_test.go index 089e0df2c5..83e0bc00e7 100644 --- a/pex_test.go +++ b/pex_test.go @@ -39,9 +39,9 @@ func TestPexReset(t *testing.T) { {Peer: Peer{RemoteAddr: addrs[1]}}, {Peer: Peer{RemoteAddr: addrs[2]}}, } - s.Add(&conns[0]) - s.Add(&conns[1]) - s.Drop(&conns[0]) + s.Add(&conns[0], true) + s.Add(&conns[1], true) + s.Drop(&conns[0], true) s.Reset() targ := new(pexState) require.EqualValues(t, targ, s) @@ -72,7 +72,7 @@ var testcases = []struct { in: func() *pexState { s := new(pexState) nullAddr := &net.TCPAddr{} - s.Add(&PeerConn{Peer: Peer{RemoteAddr: nullAddr}}) + s.Add(&PeerConn{Peer: Peer{RemoteAddr: nullAddr}}, true) return s }(), targ: pp.PexMsg{}, @@ -82,7 +82,7 @@ var testcases = []struct { in: func() *pexState { nullAddr := &net.TCPAddr{} s := new(pexState) - s.Drop(&PeerConn{Peer: Peer{RemoteAddr: nullAddr}, pex: pexConnState{Listed: true}}) + s.Drop(&PeerConn{Peer: Peer{RemoteAddr: nullAddr}, pex: pexConnState{Listed: true}}, true) return s }(), targ: pp.PexMsg{}, @@ -91,10 +91,10 @@ var testcases = []struct { name: "add4", in: func() *pexState { s := new(pexState) - s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[0]}}) - s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[1], outgoing: true}}) - s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[2], outgoing: true}}) - s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[3]}}) + s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[0]}}, true) + s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[1], outgoing: true}}, true) + s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[2], outgoing: true}}, true) + s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[3]}}, true) return s }(), targ: pp.PexMsg{ @@ -114,8 +114,8 @@ var testcases = []struct { name: "drop2", in: func() *pexState { s := &pexState{nc: pexTargAdded + 2} - s.Drop(&PeerConn{Peer: Peer{RemoteAddr: addrs[0]}, pex: pexConnState{Listed: true}}) - s.Drop(&PeerConn{Peer: Peer{RemoteAddr: addrs[2]}, pex: pexConnState{Listed: true}}) + s.Drop(&PeerConn{Peer: Peer{RemoteAddr: addrs[0]}, pex: pexConnState{Listed: true}}, true) + s.Drop(&PeerConn{Peer: Peer{RemoteAddr: addrs[2]}, pex: pexConnState{Listed: true}}, true) return s }(), targ: pp.PexMsg{ @@ -136,10 +136,10 @@ var testcases = []struct { {Peer: Peer{RemoteAddr: addrs[2]}}, } s := &pexState{nc: pexTargAdded} - s.Add(&conns[0]) - s.Add(&conns[1]) - s.Drop(&conns[0]) - s.Drop(&conns[2]) // to be ignored: it wasn't added + s.Add(&conns[0], true) + s.Add(&conns[1], true) + s.Drop(&conns[0], true) + s.Drop(&conns[2], true) // to be ignored: it wasn't added return s }(), targ: pp.PexMsg{ @@ -158,12 +158,12 @@ var testcases = []struct { {Peer: Peer{RemoteAddr: addrs[2]}}, } s := new(pexState) - s.Add(&conns[0]) - s.Add(&conns[1]) - s.Add(&conns[2]) - s.Drop(&conns[0]) // on hold: s.nc < pexTargAdded - s.Drop(&conns[2]) - s.Drop(&conns[1]) + s.Add(&conns[0], true) + s.Add(&conns[1], true) + s.Add(&conns[2], true) + s.Drop(&conns[0], true) // on hold: s.nc < pexTargAdded + s.Drop(&conns[2], true) + s.Drop(&conns[1], true) return s }(), targ: pp.PexMsg{ @@ -186,9 +186,9 @@ var testcases = []struct { {Peer: Peer{RemoteAddr: addrs[1]}}, } s := &pexState{nc: pexTargAdded - 1} - s.Add(&conns[0]) - s.Drop(&conns[0]) // on hold: s.nc < pexTargAdded - s.Add(&conns[1]) // unholds the above + s.Add(&conns[0], true) + s.Drop(&conns[0], true) // on hold: s.nc < pexTargAdded + s.Add(&conns[1], true) // unholds the above return s }(), targ: pp.PexMsg{ @@ -202,7 +202,7 @@ var testcases = []struct { name: "followup", in: func() *pexState { s := new(pexState) - s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[0]}}) + s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[0]}}, true) return s }(), targ: pp.PexMsg{ @@ -212,7 +212,7 @@ var testcases = []struct { Added6Flags: []pp.PexPeerFlags{0}, }, update: func(s *pexState) { - s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[1]}}) + s.Add(&PeerConn{Peer: Peer{RemoteAddr: addrs[1]}}, true) }, targ1: pp.PexMsg{ Added6: krpc.CompactIPv6NodeAddrs{ @@ -296,7 +296,7 @@ func TestPexInitialNoCutoff(t *testing.T) { c := addrgen(n) for addr := range c { - s.Add(&PeerConn{Peer: Peer{RemoteAddr: addr}}) + s.Add(&PeerConn{Peer: Peer{RemoteAddr: addr}}, true) } m, _ := s.Genmsg(nil) @@ -313,7 +313,7 @@ func benchmarkPexInitialN(b *testing.B, npeers int) { var s pexState c := addrgen(npeers) for addr := range c { - s.Add(&PeerConn{Peer: Peer{RemoteAddr: addr}}) + s.Add(&PeerConn{Peer: Peer{RemoteAddr: addr}}, true) s.Genmsg(nil) } } diff --git a/torrent.go b/torrent.go index b78dbaef76..83da06fa31 100644 --- a/torrent.go +++ b/torrent.go @@ -1940,7 +1940,7 @@ func (t *Torrent) deletePeerConn(c *PeerConn, lock bool) (ret bool) { func() { c.mu.Lock() defer c.mu.Unlock() - t.pex.Drop(c) + t.pex.Drop(c, false) }() } } @@ -2464,7 +2464,7 @@ func (t *Torrent) addPeerConn(c *PeerConn, lockTorrent bool) (err error) { // We'll never receive the "p" extended handshake parameter. if !t.cl.config.DisablePEX && !c.PeerExtensionBytes.SupportsExtended() { c.mu.Lock() - t.pex.Add(c) + t.pex.Add(c, false) c.mu.Unlock() } @@ -3484,7 +3484,7 @@ func (t *Torrent) peerConnsWithDialAddrPort(target netip.AddrPort, lock bool) (r } for pc := range t.conns { - dialAddr, err := pc.remoteDialAddrPort() + dialAddr, err := pc.remoteDialAddrPort(true) if err != nil { continue } @@ -3553,7 +3553,7 @@ func (t *Torrent) handleReceivedUtHolepunchMsg(msg utHolepunch.Msg, sender *Peer case utHolepunch.Rendezvous: t.logger.Printf("got holepunch rendezvous request for %v from %p", msg.AddrPort, sender) sendMsg := sendUtHolepunchMsg - senderAddrPort, err := sender.remoteDialAddrPort() + senderAddrPort, err := sender.remoteDialAddrPort(true) if err != nil { sender.logger.Levelf( log.Warning,