From 75ee7298f19615491eebc2f1b03b5241ccdb8b0a Mon Sep 17 00:00:00 2001 From: abhishek818 Date: Wed, 17 Jul 2024 15:10:12 +0530 Subject: [PATCH] Support multiple LDAP servers in a auth source (#6898) Signed-off-by: abhishek818 --- cmd/admin_auth_ldap.go | 2 +- cmd/admin_auth_ldap_test.go | 16 +++---- routers/web/admin/auths.go | 2 +- services/auth/source/ldap/source.go | 2 +- services/auth/source/ldap/source_search.go | 56 +++++++++++++++------- 5 files changed, 49 insertions(+), 29 deletions(-) diff --git a/cmd/admin_auth_ldap.go b/cmd/admin_auth_ldap.go index e3c81809f8d24..1bef2f00daa1a 100644 --- a/cmd/admin_auth_ldap.go +++ b/cmd/admin_auth_ldap.go @@ -207,7 +207,7 @@ func parseLdapConfig(c *cli.Context, config *ldap.Source) error { config.Name = c.String("name") } if c.IsSet("host") { - config.Host = c.String("host") + config.HostList = c.String("hostlist") } if c.IsSet("port") { config.Port = c.Int("port") diff --git a/cmd/admin_auth_ldap_test.go b/cmd/admin_auth_ldap_test.go index 7791f3a9cc14b..e987782e61da8 100644 --- a/cmd/admin_auth_ldap_test.go +++ b/cmd/admin_auth_ldap_test.go @@ -59,7 +59,7 @@ func TestAddLdapBindDn(t *testing.T) { IsSyncEnabled: true, Cfg: &ldap.Source{ Name: "ldap (via Bind DN) source full", - Host: "ldap-bind-server full", + HostList: "ldap-bind-server full", Port: 9876, SecurityProtocol: ldap.SecurityProtocol(1), SkipVerify: true, @@ -99,7 +99,7 @@ func TestAddLdapBindDn(t *testing.T) { IsActive: true, Cfg: &ldap.Source{ Name: "ldap (via Bind DN) source min", - Host: "ldap-bind-server min", + HostList: "ldap-bind-server min", Port: 1234, SecurityProtocol: ldap.SecurityProtocol(0), UserBase: "ou=Users,dc=min-domain-bind,dc=org", @@ -280,7 +280,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { IsActive: false, Cfg: &ldap.Source{ Name: "ldap (simple auth) source full", - Host: "ldap-simple-server full", + HostList: "ldap-simple-server full", Port: 987, SecurityProtocol: ldap.SecurityProtocol(2), SkipVerify: true, @@ -317,7 +317,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { IsActive: true, Cfg: &ldap.Source{ Name: "ldap (simple auth) source min", - Host: "ldap-simple-server min", + HostList: "ldap-simple-server min", Port: 123, SecurityProtocol: ldap.SecurityProtocol(0), UserDN: "cn=%s,ou=Users,dc=min-domain-simple,dc=org", @@ -526,7 +526,7 @@ func TestUpdateLdapBindDn(t *testing.T) { IsSyncEnabled: true, Cfg: &ldap.Source{ Name: "ldap (via Bind DN) source full", - Host: "ldap-bind-server full", + HostList: "ldap-bind-server full", Port: 9876, SecurityProtocol: ldap.SecurityProtocol(1), SkipVerify: true, @@ -630,7 +630,7 @@ func TestUpdateLdapBindDn(t *testing.T) { authSource: &auth.Source{ Type: auth.LDAP, Cfg: &ldap.Source{ - Host: "ldap-server", + HostList: "ldap-server", }, }, }, @@ -978,7 +978,7 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { IsActive: false, Cfg: &ldap.Source{ Name: "ldap (simple auth) source full", - Host: "ldap-simple-server full", + HostList: "ldap-simple-server full", Port: 987, SecurityProtocol: ldap.SecurityProtocol(2), SkipVerify: true, @@ -1078,7 +1078,7 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { authSource: &auth.Source{ Type: auth.DLDAP, Cfg: &ldap.Source{ - Host: "ldap-server", + HostList: "ldap-server", }, }, }, diff --git a/routers/web/admin/auths.go b/routers/web/admin/auths.go index 3b89be0f8fc26..4de0bb277a497 100644 --- a/routers/web/admin/auths.go +++ b/routers/web/admin/auths.go @@ -121,7 +121,7 @@ func parseLDAPConfig(form forms.AuthenticationForm) *ldap.Source { } return &ldap.Source{ Name: form.Name, - Host: form.Host, + HostList: form.Host, Port: form.Port, SecurityProtocol: ldap.SecurityProtocol(form.SecurityProtocol), SkipVerify: form.SkipVerify, diff --git a/services/auth/source/ldap/source.go b/services/auth/source/ldap/source.go index dc4cb2c94031b..8f986db1805c0 100644 --- a/services/auth/source/ldap/source.go +++ b/services/auth/source/ldap/source.go @@ -25,7 +25,7 @@ import ( // Source Basic LDAP authentication service type Source struct { Name string // canonical name (ie. corporate.ad) - Host string // LDAP host + HostList string // list containing LDAP host(s) Port int // port number SecurityProtocol SecurityProtocol SkipVerify bool diff --git a/services/auth/source/ldap/source_search.go b/services/auth/source/ldap/source_search.go index 2a61386ae1061..9b899e7eb9e06 100644 --- a/services/auth/source/ldap/source_search.go +++ b/services/auth/source/ldap/source_search.go @@ -10,6 +10,7 @@ import ( "net" "strconv" "strings" + "time" "code.gitea.io/gitea/modules/container" "code.gitea.io/gitea/modules/log" @@ -111,28 +112,47 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) { func dial(source *Source) (*ldap.Conn, error) { log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify) - tlsConfig := &tls.Config{ - ServerName: source.Host, - InsecureSkipVerify: source.SkipVerify, - } + ldap.DefaultTimeout = time.Second * 15 + // HostList is a list of hosts separated by commas + hostList := strings.Split(source.HostList, ",") - if source.SecurityProtocol == SecurityProtocolLDAPS { - return ldap.DialTLS("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port)), tlsConfig) - } + for _, host := range hostList { + tlsConfig := &tls.Config{ + ServerName: host, + InsecureSkipVerify: source.SkipVerify, + } - conn, err := ldap.Dial("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port))) - if err != nil { - return nil, fmt.Errorf("error during Dial: %w", err) - } + if source.SecurityProtocol == SecurityProtocolLDAPS { + conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) + + if err != nil { + // Connection failed, try again with the next host. + log.Trace("error during Dial for host %s: %w", host, err) + continue + } + conn.SetTimeout(time.Second * 10) - if source.SecurityProtocol == SecurityProtocolStartTLS { - if err = conn.StartTLS(tlsConfig); err != nil { - conn.Close() - return nil, fmt.Errorf("error during StartTLS: %w", err) + return conn, err + } + + conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) + if err != nil { + log.Trace("error during Dial for host %s: %w", host, err) + continue + } + conn.SetTimeout(time.Second * 10) + + if source.SecurityProtocol == SecurityProtocolStartTLS { + if err = conn.StartTLS(tlsConfig); err != nil { + conn.Close() + log.Trace("error during StartTLS for host %s: %w", host, err) + continue + } } } - return conn, nil + // All servers were unreachable + return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList) } func bindUser(l *ldap.Conn, userDN, passwd string) error { @@ -257,7 +277,7 @@ func (source *Source) SearchEntry(name, passwd string, directBind bool) *SearchR } l, err := dial(source) if err != nil { - log.Error("LDAP Connect error, %s:%v", source.Host, err) + log.Error("LDAP Connect error, %s:%v", source.HostList, err) source.Enabled = false return nil } @@ -421,7 +441,7 @@ func (source *Source) UsePagedSearch() bool { func (source *Source) SearchEntries() ([]*SearchResult, error) { l, err := dial(source) if err != nil { - log.Error("LDAP Connect error, %s:%v", source.Host, err) + log.Error("LDAP Connect error, %s:%v", source.HostList, err) source.Enabled = false return nil, err }