Skip to content

Commit

Permalink
proxymap, various: distinguish between different protocols
Browse files Browse the repository at this point in the history
Previously, we were registering TCP and UDP connections in the same map,
which could result in erroneously removing a mapping if one of the two
connections completes while the other one is still active.

Add a "proto string" argument to these functions to avoid this.
Additionally, take the "proto" argument in LocalAPI, and plumb that
through from the CLI and add a new LocalClient method.

Updates tailscale/corp#20600

Signed-off-by: Andrew Dunham <[email protected]>
Change-Id: I35d5efaefdfbf4721e315b8ca123f0c8af9125fb
  • Loading branch information
andrew-d authored and Asutorufa committed Aug 23, 2024
1 parent 819d849 commit 081b0e2
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 30 deletions.
19 changes: 19 additions & 0 deletions client/tailscale/localclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ func decodeJSON[T any](b []byte) (ret T, err error) {
// WhoIs returns the owner of the remoteAddr, which must be an IP or IP:port.
//
// If not found, the error is ErrPeerNotFound.
//
// For connections proxied by tailscaled, this looks up the owner of the given
// address as TCP first, falling back to UDP; if you want to only check a
// specific address family, use WhoIsProto.
func (lc *LocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
body, err := lc.get200(ctx, "/localapi/v0/whois?addr="+url.QueryEscape(remoteAddr))
if err != nil {
Expand Down Expand Up @@ -313,6 +317,21 @@ func (lc *LocalClient) WhoIsNodeKey(ctx context.Context, key key.NodePublic) (*a
return decodeJSON[*apitype.WhoIsResponse](body)
}

// WhoIsProto returns the owner of the remoteAddr, which must be an IP or
// IP:port, for the given protocol (tcp or udp).
//
// If not found, the error is ErrPeerNotFound.
func (lc *LocalClient) WhoIsProto(ctx context.Context, proto, remoteAddr string) (*apitype.WhoIsResponse, error) {
body, err := lc.get200(ctx, "/localapi/v0/whois?proto="+url.QueryEscape(proto)+"&addr="+url.QueryEscape(remoteAddr))
if err != nil {
if hs, ok := err.(httpStatusError); ok && hs.HTTPStatus == http.StatusNotFound {
return nil, ErrPeerNotFound
}
return nil, err
}
return decodeJSON[*apitype.WhoIsResponse](body)
}

// Goroutines returns a dump of the Tailscale daemon's current goroutines.
func (lc *LocalClient) Goroutines(ctx context.Context) ([]byte, error) {
return lc.get200(ctx, "/localapi/v0/goroutines")
Expand Down
6 changes: 4 additions & 2 deletions cmd/tailscale/cli/whois.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ var whoisCmd = &ffcli.Command{
FlagSet: func() *flag.FlagSet {
fs := newFlagSet("whois")
fs.BoolVar(&whoIsArgs.json, "json", false, "output in JSON format")
fs.StringVar(&whoIsArgs.proto, "proto", "", `protocol; one of "tcp" or "udp"; empty mans both `)
return fs
}(),
}

var whoIsArgs struct {
json bool // output in JSON format
json bool // output in JSON format
proto string // "tcp" or "udp"
}

func runWhoIs(ctx context.Context, args []string) error {
Expand All @@ -40,7 +42,7 @@ func runWhoIs(ctx context.Context, args []string) error {
} else if len(args) == 0 {
return errors.New("missing argument, expected one peer")
}
who, err := localClient.WhoIs(ctx, args[0])
who, err := localClient.WhoIsProto(ctx, whoIsArgs.proto, args[0])
if err != nil {
return err
}
Expand Down
26 changes: 23 additions & 3 deletions ipn/ipnlocal/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -995,8 +995,15 @@ func (b *LocalBackend) WhoIsNodeKey(k key.NodePublic) (n tailcfg.NodeView, u tai

// WhoIs reports the node and user who owns the node with the given IP:port.
// If the IP address is a Tailscale IP, the provided port may be 0.
//
// The 'proto' is used when looking up the IP:port in our proxy mapper; it
// tracks which local IP:ports correspond to connections proxied by tailscaled,
// and since tailscaled proxies both TCP and UDP, the 'proto' is needed to look
// up the correct IP:port based on the connection's protocol. If not provided,
// the lookup will be done for TCP and then UDP, in that order.
//
// If ok == true, n and u are valid.
func (b *LocalBackend) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
func (b *LocalBackend) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
var zero tailcfg.NodeView
b.mu.Lock()
defer b.mu.Unlock()
Expand All @@ -1005,7 +1012,20 @@ func (b *LocalBackend) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.
if !ok {
var ip netip.Addr
if ipp.Port() != 0 {
ip, ok = b.sys.ProxyMapper().WhoIsIPPort(ipp)
var protos []string
if proto != "" {
protos = []string{proto}
} else {
// If the user didn't specify a protocol, try all of them
protos = []string{"tcp", "udp"}
}

for _, tryproto := range protos {
ip, ok = b.sys.ProxyMapper().WhoIsIPPort(tryproto, ipp)
if ok {
break
}
}
}
if !ok {
return zero, u, false
Expand Down Expand Up @@ -5044,7 +5064,7 @@ func (dt *driveTransport) RoundTrip(req *http.Request) (resp *http.Response, err
dt.b.mu.Lock()
selfNodeKey := dt.b.netMap.SelfNode.Key().ShortString()
dt.b.mu.Unlock()
n, _, ok := dt.b.WhoIs(netip.MustParseAddrPort(req.URL.Host))
n, _, ok := dt.b.WhoIs("tcp", netip.MustParseAddrPort(req.URL.Host))
shareNodeKey := "unknown"
if ok {
shareNodeKey = string(n.Key().ShortString())
Expand Down
2 changes: 1 addition & 1 deletion ipn/ipnlocal/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ func TestWhoIs(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.q, func(t *testing.T) {
nv, up, ok := b.WhoIs(netip.MustParseAddrPort(tt.q))
nv, up, ok := b.WhoIs("", netip.MustParseAddrPort(tt.q))
var got tailcfg.NodeID
if ok {
got = nv.ID()
Expand Down
2 changes: 1 addition & 1 deletion ipn/ipnlocal/peerapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (pln *peerAPIListener) serve() {

func (pln *peerAPIListener) ServeConn(src netip.AddrPort, c net.Conn) {
logf := pln.lb.logf
peerNode, peerUser, ok := pln.lb.WhoIs(src)
peerNode, peerUser, ok := pln.lb.WhoIs("tcp", src)
if !ok {
logf("peerapi: unknown peer %v", src)
c.Close()
Expand Down
2 changes: 1 addition & 1 deletion ipn/ipnlocal/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ func (b *LocalBackend) addTailscaleIdentityHeaders(r *httputil.ProxyRequest) {
if !ok {
return
}
node, user, ok := b.WhoIs(c.SrcAddr)
node, user, ok := b.WhoIs("tcp", c.SrcAddr)
if !ok {
return // traffic from outside of Tailnet (funneled)
}
Expand Down
4 changes: 2 additions & 2 deletions ipn/localapi/localapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ func (h *Handler) serveWhoIs(w http.ResponseWriter, r *http.Request) {
// localBackendWhoIsMethods is the subset of ipn.LocalBackend as needed
// by the localapi WhoIs method.
type localBackendWhoIsMethods interface {
WhoIs(netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
WhoIs(string, netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
WhoIsNodeKey(key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
PeerCaps(netip.Addr) tailcfg.PeerCapMap
}
Expand Down Expand Up @@ -482,7 +482,7 @@ func (h *Handler) serveWhoIsWithBackend(w http.ResponseWriter, r *http.Request,
}
}
if ipp.IsValid() {
n, u, ok = b.WhoIs(ipp)
n, u, ok = b.WhoIs(r.FormValue("proto"), ipp)
}
} else {
http.Error(w, "missing 'addr' parameter", http.StatusBadRequest)
Expand Down
8 changes: 4 additions & 4 deletions ipn/localapi/localapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ func TestSetPushDeviceToken(t *testing.T) {
}

type whoIsBackend struct {
whoIs func(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
whoIs func(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
whoIsNodeKey func(key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
peerCaps map[netip.Addr]tailcfg.PeerCapMap
}

func (b whoIsBackend) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
return b.whoIs(ipp)
func (b whoIsBackend) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
return b.whoIs(proto, ipp)
}

func (b whoIsBackend) WhoIsNodeKey(k key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
Expand Down Expand Up @@ -143,7 +143,7 @@ func TestWhoIsArgTypes(t *testing.T) {
rec := httptest.NewRecorder()
t.Run(input, func(t *testing.T) {
b := whoIsBackend{
whoIs: func(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
whoIs: func(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
if !strings.Contains(input, ":") {
want := netip.MustParseAddrPort("100.101.102.103:0")
if ipp != want {
Expand Down
32 changes: 23 additions & 9 deletions proxymap/proxymap.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (
"net/netip"
"sync"
"time"

"tailscale.com/util/mak"
)

// Mapper tracks which localhost ip:ports correspond to which remote Tailscale
Expand All @@ -21,26 +19,39 @@ import (
// given localhost:port corresponds to.
type Mapper struct {
mu sync.Mutex
m map[netip.AddrPort]netip.Addr
m map[string]map[netip.AddrPort]netip.Addr // proto ("tcp", "udp") => ephemeral => tailscale IP
}

// RegisterIPPortIdentity registers a given node (identified by its
// Tailscale IP) as temporarily having the given IP:port for whois lookups.
//
// The IP:port is generally a localhost IP and an ephemeral port, used
// while proxying connections to localhost when tailscaled is running
// in netstack mode.
func (m *Mapper) RegisterIPPortIdentity(ipport netip.AddrPort, tsIP netip.Addr) {
//
// The proto is the network protocol that is being proxied; it must be "tcp" or
// "udp" (not e.g. "tcp4", "udp6", etc.)
func (m *Mapper) RegisterIPPortIdentity(proto string, ipport netip.AddrPort, tsIP netip.Addr) {
m.mu.Lock()
defer m.mu.Unlock()
mak.Set(&m.m, ipport, tsIP)
if m.m == nil {
m.m = make(map[string]map[netip.AddrPort]netip.Addr)
}
p, ok := m.m[proto]
if !ok {
p = make(map[netip.AddrPort]netip.Addr)
m.m[proto] = p
}
p[ipport] = tsIP
}

// UnregisterIPPortIdentity removes a temporary IP:port registration
// made previously by RegisterIPPortIdentity.
func (m *Mapper) UnregisterIPPortIdentity(ipport netip.AddrPort) {
func (m *Mapper) UnregisterIPPortIdentity(proto string, ipport netip.AddrPort) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.m, ipport)
p := m.m[proto]
delete(p, ipport) // safe to delete from a nil map
}

var whoIsSleeps = [...]time.Duration{
Expand All @@ -53,7 +64,7 @@ var whoIsSleeps = [...]time.Duration{

// WhoIsIPPort looks up an IP:port in the temporary registrations,
// and returns a matching Tailscale IP, if it exists.
func (m *Mapper) WhoIsIPPort(ipport netip.AddrPort) (tsIP netip.Addr, ok bool) {
func (m *Mapper) WhoIsIPPort(proto string, ipport netip.AddrPort) (tsIP netip.Addr, ok bool) {
// We currently have a registration race,
// https://github.com/tailscale/tailscale/issues/1616,
// so loop a few times for now waiting for the registration
Expand All @@ -62,7 +73,10 @@ func (m *Mapper) WhoIsIPPort(ipport netip.AddrPort) (tsIP netip.Addr, ok bool) {
for _, d := range whoIsSleeps {
time.Sleep(d)
m.mu.Lock()
tsIP, ok = m.m[ipport]
p, ok := m.m[proto]
if ok {
tsIP, ok = p[ipport]
}
m.mu.Unlock()
if ok {
return tsIP, true
Expand Down
4 changes: 2 additions & 2 deletions ssh/tailssh/tailssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ type ipnLocalBackend interface {
GetSSH_HostKeys() ([]gossh.Signer, error)
ShouldRunSSH() bool
NetMap() *netmap.NetworkMap
WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
DoNoiseRequest(req *http.Request) (*http.Response, error)
Dialer() *tsdial.Dialer
TailscaleVarRoot() string
Expand Down Expand Up @@ -604,7 +604,7 @@ func (c *conn) setInfo(ctx ssh.Context) error {
if !tsaddr.IsTailscaleIP(ci.src.Addr()) {
return fmt.Errorf("tailssh: rejecting non-Tailscale remote address %v", ci.src)
}
node, uprof, ok := c.srv.lb.WhoIs(ci.src)
node, uprof, ok := c.srv.lb.WhoIs("tcp", ci.src)
if !ok {
return fmt.Errorf("unknown Tailscale identity from src %v", ci.src)
}
Expand Down
6 changes: 5 additions & 1 deletion ssh/tailssh/tailssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,11 @@ func (ts *localState) NetMap() *netmap.NetworkMap {
}
}

func (ts *localState) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
func (ts *localState) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
if proto != "tcp" {
return tailcfg.NodeView{}, tailcfg.UserProfile{}, false
}

return (&tailcfg.Node{
ID: 2,
StableID: "peer-id",
Expand Down
8 changes: 4 additions & 4 deletions wgengine/netstack/netstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -1328,8 +1328,8 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.

backendLocalAddr := server.LocalAddr().(*net.TCPAddr)
backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort())
ns.pm.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
defer ns.pm.UnregisterIPPortIdentity(backendLocalIPPort)
ns.pm.RegisterIPPortIdentity("tcp", backendLocalIPPort, clientRemoteIP)
defer ns.pm.UnregisterIPPortIdentity("tcp", backendLocalIPPort)
connClosed := make(chan error, 2)
go func() {
_, err := io.Copy(server, client)
Expand Down Expand Up @@ -1533,7 +1533,7 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, clientAddr, dstAddr netip.Addr
ns.logf("could not get backend local IP:port from %v:%v", backendLocalAddr.IP, backendLocalAddr.Port)
}
if isLocal {
ns.pm.RegisterIPPortIdentity(backendLocalIPPort, dstAddr.Addr())
ns.pm.RegisterIPPortIdentity("udp", backendLocalIPPort, clientAddr.Addr())
}
ctx, cancel := context.WithCancel(context.Background())

Expand All @@ -1549,7 +1549,7 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, clientAddr, dstAddr netip.Addr
}
timer := time.AfterFunc(idleTimeout, func() {
if isLocal {
ns.pm.UnregisterIPPortIdentity(backendLocalIPPort)
ns.pm.UnregisterIPPortIdentity("udp", backendLocalIPPort)
}
ns.logf("netstack: UDP session between %s and %s timed out", backendListenAddr, backendRemoteAddr)
cancel()
Expand Down

0 comments on commit 081b0e2

Please sign in to comment.