From f954681e6981b9fd6e5a95529d2e964cf97ad738 Mon Sep 17 00:00:00 2001 From: abhishek818 Date: Wed, 17 Jul 2024 15:10:12 +0530 Subject: [PATCH 1/7] Support multiple LDAP servers in a auth source (#6898) Signed-off-by: abhishek818 --- cmd/admin_auth_ldap.go | 2 +- cmd/admin_auth_ldap_test.go | 16 +++---- routers/web/admin/auths.go | 2 +- services/auth/source/ldap/source.go | 2 +- services/auth/source/ldap/source_search.go | 56 +++++++++++++++------- 5 files changed, 49 insertions(+), 29 deletions(-) diff --git a/cmd/admin_auth_ldap.go b/cmd/admin_auth_ldap.go index e3c81809f8d2..1bef2f00daa1 100644 --- a/cmd/admin_auth_ldap.go +++ b/cmd/admin_auth_ldap.go @@ -207,7 +207,7 @@ func parseLdapConfig(c *cli.Context, config *ldap.Source) error { config.Name = c.String("name") } if c.IsSet("host") { - config.Host = c.String("host") + config.HostList = c.String("hostlist") } if c.IsSet("port") { config.Port = c.Int("port") diff --git a/cmd/admin_auth_ldap_test.go b/cmd/admin_auth_ldap_test.go index 7791f3a9cc14..e987782e61da 100644 --- a/cmd/admin_auth_ldap_test.go +++ b/cmd/admin_auth_ldap_test.go @@ -59,7 +59,7 @@ func TestAddLdapBindDn(t *testing.T) { IsSyncEnabled: true, Cfg: &ldap.Source{ Name: "ldap (via Bind DN) source full", - Host: "ldap-bind-server full", + HostList: "ldap-bind-server full", Port: 9876, SecurityProtocol: ldap.SecurityProtocol(1), SkipVerify: true, @@ -99,7 +99,7 @@ func TestAddLdapBindDn(t *testing.T) { IsActive: true, Cfg: &ldap.Source{ Name: "ldap (via Bind DN) source min", - Host: "ldap-bind-server min", + HostList: "ldap-bind-server min", Port: 1234, SecurityProtocol: ldap.SecurityProtocol(0), UserBase: "ou=Users,dc=min-domain-bind,dc=org", @@ -280,7 +280,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { IsActive: false, Cfg: &ldap.Source{ Name: "ldap (simple auth) source full", - Host: "ldap-simple-server full", + HostList: "ldap-simple-server full", Port: 987, SecurityProtocol: ldap.SecurityProtocol(2), SkipVerify: true, @@ -317,7 +317,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { IsActive: true, Cfg: &ldap.Source{ Name: "ldap (simple auth) source min", - Host: "ldap-simple-server min", + HostList: "ldap-simple-server min", Port: 123, SecurityProtocol: ldap.SecurityProtocol(0), UserDN: "cn=%s,ou=Users,dc=min-domain-simple,dc=org", @@ -526,7 +526,7 @@ func TestUpdateLdapBindDn(t *testing.T) { IsSyncEnabled: true, Cfg: &ldap.Source{ Name: "ldap (via Bind DN) source full", - Host: "ldap-bind-server full", + HostList: "ldap-bind-server full", Port: 9876, SecurityProtocol: ldap.SecurityProtocol(1), SkipVerify: true, @@ -630,7 +630,7 @@ func TestUpdateLdapBindDn(t *testing.T) { authSource: &auth.Source{ Type: auth.LDAP, Cfg: &ldap.Source{ - Host: "ldap-server", + HostList: "ldap-server", }, }, }, @@ -978,7 +978,7 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { IsActive: false, Cfg: &ldap.Source{ Name: "ldap (simple auth) source full", - Host: "ldap-simple-server full", + HostList: "ldap-simple-server full", Port: 987, SecurityProtocol: ldap.SecurityProtocol(2), SkipVerify: true, @@ -1078,7 +1078,7 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { authSource: &auth.Source{ Type: auth.DLDAP, Cfg: &ldap.Source{ - Host: "ldap-server", + HostList: "ldap-server", }, }, }, diff --git a/routers/web/admin/auths.go b/routers/web/admin/auths.go index 3b89be0f8fc2..4de0bb277a49 100644 --- a/routers/web/admin/auths.go +++ b/routers/web/admin/auths.go @@ -121,7 +121,7 @@ func parseLDAPConfig(form forms.AuthenticationForm) *ldap.Source { } return &ldap.Source{ Name: form.Name, - Host: form.Host, + HostList: form.Host, Port: form.Port, SecurityProtocol: ldap.SecurityProtocol(form.SecurityProtocol), SkipVerify: form.SkipVerify, diff --git a/services/auth/source/ldap/source.go b/services/auth/source/ldap/source.go index dc4cb2c94031..8f986db1805c 100644 --- a/services/auth/source/ldap/source.go +++ b/services/auth/source/ldap/source.go @@ -25,7 +25,7 @@ import ( // Source Basic LDAP authentication service type Source struct { Name string // canonical name (ie. corporate.ad) - Host string // LDAP host + HostList string // list containing LDAP host(s) Port int // port number SecurityProtocol SecurityProtocol SkipVerify bool diff --git a/services/auth/source/ldap/source_search.go b/services/auth/source/ldap/source_search.go index 2a61386ae106..9b899e7eb9e0 100644 --- a/services/auth/source/ldap/source_search.go +++ b/services/auth/source/ldap/source_search.go @@ -10,6 +10,7 @@ import ( "net" "strconv" "strings" + "time" "code.gitea.io/gitea/modules/container" "code.gitea.io/gitea/modules/log" @@ -111,28 +112,47 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) { func dial(source *Source) (*ldap.Conn, error) { log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify) - tlsConfig := &tls.Config{ - ServerName: source.Host, - InsecureSkipVerify: source.SkipVerify, - } + ldap.DefaultTimeout = time.Second * 15 + // HostList is a list of hosts separated by commas + hostList := strings.Split(source.HostList, ",") - if source.SecurityProtocol == SecurityProtocolLDAPS { - return ldap.DialTLS("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port)), tlsConfig) - } + for _, host := range hostList { + tlsConfig := &tls.Config{ + ServerName: host, + InsecureSkipVerify: source.SkipVerify, + } - conn, err := ldap.Dial("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port))) - if err != nil { - return nil, fmt.Errorf("error during Dial: %w", err) - } + if source.SecurityProtocol == SecurityProtocolLDAPS { + conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) + + if err != nil { + // Connection failed, try again with the next host. + log.Trace("error during Dial for host %s: %w", host, err) + continue + } + conn.SetTimeout(time.Second * 10) - if source.SecurityProtocol == SecurityProtocolStartTLS { - if err = conn.StartTLS(tlsConfig); err != nil { - conn.Close() - return nil, fmt.Errorf("error during StartTLS: %w", err) + return conn, err + } + + conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) + if err != nil { + log.Trace("error during Dial for host %s: %w", host, err) + continue + } + conn.SetTimeout(time.Second * 10) + + if source.SecurityProtocol == SecurityProtocolStartTLS { + if err = conn.StartTLS(tlsConfig); err != nil { + conn.Close() + log.Trace("error during StartTLS for host %s: %w", host, err) + continue + } } } - return conn, nil + // All servers were unreachable + return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList) } func bindUser(l *ldap.Conn, userDN, passwd string) error { @@ -257,7 +277,7 @@ func (source *Source) SearchEntry(name, passwd string, directBind bool) *SearchR } l, err := dial(source) if err != nil { - log.Error("LDAP Connect error, %s:%v", source.Host, err) + log.Error("LDAP Connect error, %s:%v", source.HostList, err) source.Enabled = false return nil } @@ -421,7 +441,7 @@ func (source *Source) UsePagedSearch() bool { func (source *Source) SearchEntries() ([]*SearchResult, error) { l, err := dial(source) if err != nil { - log.Error("LDAP Connect error, %s:%v", source.Host, err) + log.Error("LDAP Connect error, %s:%v", source.HostList, err) source.Enabled = false return nil, err } From b95b9a85972fdfcdb631a7da6ec41f563f9586c9 Mon Sep 17 00:00:00 2001 From: abhishek818 Date: Thu, 18 Jul 2024 13:17:36 +0530 Subject: [PATCH 2/7] rename ldap' cli flag 'host' to 'host-list' and fix tests (#6898) Signed-off-by: abhishek818 --- cmd/admin_auth_ldap.go | 12 +++--- cmd/admin_auth_ldap_test.go | 46 +++++++++++----------- services/auth/source/ldap/README.md | 3 +- services/auth/source/ldap/source_search.go | 5 ++- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/cmd/admin_auth_ldap.go b/cmd/admin_auth_ldap.go index 1bef2f00daa1..dd435cc1c220 100644 --- a/cmd/admin_auth_ldap.go +++ b/cmd/admin_auth_ldap.go @@ -46,8 +46,8 @@ var ( Usage: "Disable TLS verification.", }, &cli.StringFlag{ - Name: "host", - Usage: "The address where the LDAP server can be reached.", + Name: "host-list", + Usage: "List of addresses where the LDAP server(s) can be reached.", }, &cli.IntFlag{ Name: "port", @@ -206,8 +206,8 @@ func parseLdapConfig(c *cli.Context, config *ldap.Source) error { if c.IsSet("name") { config.Name = c.String("name") } - if c.IsSet("host") { - config.HostList = c.String("hostlist") + if c.IsSet("host-list") { + config.HostList = c.String("host-list") } if c.IsSet("port") { config.Port = c.Int("port") @@ -308,7 +308,7 @@ func (a *authService) getAuthSource(ctx context.Context, c *cli.Context, authTyp // addLdapBindDn adds a new LDAP via Bind DN authentication source. func (a *authService) addLdapBindDn(c *cli.Context) error { - if err := argsSet(c, "name", "security-protocol", "host", "port", "user-search-base", "user-filter", "email-attribute"); err != nil { + if err := argsSet(c, "name", "security-protocol", "host-list", "port", "user-search-base", "user-filter", "email-attribute"); err != nil { return err } @@ -359,7 +359,7 @@ func (a *authService) updateLdapBindDn(c *cli.Context) error { // addLdapSimpleAuth adds a new LDAP (simple auth) authentication source. func (a *authService) addLdapSimpleAuth(c *cli.Context) error { - if err := argsSet(c, "name", "security-protocol", "host", "port", "user-dn", "user-filter", "email-attribute"); err != nil { + if err := argsSet(c, "name", "security-protocol", "host-list", "port", "user-dn", "user-filter", "email-attribute"); err != nil { return err } diff --git a/cmd/admin_auth_ldap_test.go b/cmd/admin_auth_ldap_test.go index e987782e61da..0539a15e4ccf 100644 --- a/cmd/admin_auth_ldap_test.go +++ b/cmd/admin_auth_ldap_test.go @@ -34,7 +34,7 @@ func TestAddLdapBindDn(t *testing.T) { "--not-active", "--security-protocol", "ldaps", "--skip-tls-verify", - "--host", "ldap-bind-server full", + "--host-list", "ldap-bind-server full", "--port", "9876", "--user-search-base", "ou=Users,dc=full-domain-bind,dc=org", "--user-filter", "(memberOf=cn=user-group,ou=example,dc=full-domain-bind,dc=org)", @@ -87,7 +87,7 @@ func TestAddLdapBindDn(t *testing.T) { "ldap-test", "--name", "ldap (via Bind DN) source min", "--security-protocol", "unencrypted", - "--host", "ldap-bind-server min", + "--host-list", "ldap-bind-server min", "--port", "1234", "--user-search-base", "ou=Users,dc=min-domain-bind,dc=org", "--user-filter", "(memberOf=cn=user-group,ou=example,dc=min-domain-bind,dc=org)", @@ -115,7 +115,7 @@ func TestAddLdapBindDn(t *testing.T) { "ldap-test", "--name", "ldap (via Bind DN) source", "--security-protocol", "zzzzz", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "1234", "--user-search-base", "ou=Users,dc=domain,dc=org", "--user-filter", "(memberOf=cn=user-group,ou=example,dc=domain,dc=org)", @@ -128,7 +128,7 @@ func TestAddLdapBindDn(t *testing.T) { args: []string{ "ldap-test", "--security-protocol", "unencrypted", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "1234", "--user-search-base", "ou=Users,dc=domain,dc=org", "--user-filter", "(memberOf=cn=user-group,ou=example,dc=domain,dc=org)", @@ -141,7 +141,7 @@ func TestAddLdapBindDn(t *testing.T) { args: []string{ "ldap-test", "--name", "ldap (via Bind DN) source", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "1234", "--user-search-base", "ou=Users,dc=domain,dc=org", "--user-filter", "(memberOf=cn=user-group,ou=example,dc=domain,dc=org)", @@ -160,7 +160,7 @@ func TestAddLdapBindDn(t *testing.T) { "--user-filter", "(memberOf=cn=user-group,ou=example,dc=domain,dc=org)", "--email-attribute", "mail", }, - errMsg: "host is not set", + errMsg: "host-list is not set", }, // case 6 { @@ -168,7 +168,7 @@ func TestAddLdapBindDn(t *testing.T) { "ldap-test", "--name", "ldap (via Bind DN) source", "--security-protocol", "unencrypted", - "--host", "ldap-server", + "--host-list", "ldap-server", "--user-search-base", "ou=Users,dc=domain,dc=org", "--user-filter", "(memberOf=cn=user-group,ou=example,dc=domain,dc=org)", "--email-attribute", "mail", @@ -181,7 +181,7 @@ func TestAddLdapBindDn(t *testing.T) { "ldap-test", "--name", "ldap (via Bind DN) source", "--security-protocol", "unencrypted", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "1234", "--user-search-base", "ou=Users,dc=domain,dc=org", "--email-attribute", "mail", @@ -194,7 +194,7 @@ func TestAddLdapBindDn(t *testing.T) { "ldap-test", "--name", "ldap (via Bind DN) source", "--security-protocol", "unencrypted", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "1234", "--user-search-base", "ou=Users,dc=domain,dc=org", "--user-filter", "(memberOf=cn=user-group,ou=example,dc=domain,dc=org)", @@ -260,7 +260,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { "--not-active", "--security-protocol", "starttls", "--skip-tls-verify", - "--host", "ldap-simple-server full", + "--host-list", "ldap-simple-server full", "--port", "987", "--user-search-base", "ou=Users,dc=full-domain-simple,dc=org", "--user-filter", "(&(objectClass=posixAccount)(full-simple-cn=%s))", @@ -305,7 +305,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { "ldap-test", "--name", "ldap (simple auth) source min", "--security-protocol", "unencrypted", - "--host", "ldap-simple-server min", + "--host-list", "ldap-simple-server min", "--port", "123", "--user-filter", "(&(objectClass=posixAccount)(min-simple-cn=%s))", "--email-attribute", "mail-simple min", @@ -333,7 +333,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { "ldap-test", "--name", "ldap (simple auth) source", "--security-protocol", "zzzzz", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "123", "--user-filter", "(&(objectClass=posixAccount)(cn=%s))", "--email-attribute", "mail", @@ -346,7 +346,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { args: []string{ "ldap-test", "--security-protocol", "unencrypted", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "123", "--user-filter", "(&(objectClass=posixAccount)(cn=%s))", "--email-attribute", "mail", @@ -359,7 +359,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { args: []string{ "ldap-test", "--name", "ldap (simple auth) source", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "123", "--user-filter", "(&(objectClass=posixAccount)(cn=%s))", "--email-attribute", "mail", @@ -378,7 +378,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { "--email-attribute", "mail", "--user-dn", "cn=%s,ou=Users,dc=domain,dc=org", }, - errMsg: "host is not set", + errMsg: "host-list is not set", }, // case 6 { @@ -386,7 +386,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { "ldap-test", "--name", "ldap (simple auth) source", "--security-protocol", "unencrypted", - "--host", "ldap-server", + "--host-list", "ldap-server", "--user-filter", "(&(objectClass=posixAccount)(cn=%s))", "--email-attribute", "mail", "--user-dn", "cn=%s,ou=Users,dc=domain,dc=org", @@ -399,7 +399,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { "ldap-test", "--name", "ldap (simple auth) source", "--security-protocol", "unencrypted", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "123", "--email-attribute", "mail", "--user-dn", "cn=%s,ou=Users,dc=domain,dc=org", @@ -412,7 +412,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { "ldap-test", "--name", "ldap (simple auth) source", "--security-protocol", "unencrypted", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "123", "--user-filter", "(&(objectClass=posixAccount)(cn=%s))", "--user-dn", "cn=%s,ou=Users,dc=domain,dc=org", @@ -425,7 +425,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { "ldap-test", "--name", "ldap (simple auth) source", "--security-protocol", "unencrypted", - "--host", "ldap-server", + "--host-list", "ldap-server", "--port", "123", "--user-filter", "(&(objectClass=posixAccount)(cn=%s))", "--email-attribute", "mail", @@ -494,7 +494,7 @@ func TestUpdateLdapBindDn(t *testing.T) { "--not-active", "--security-protocol", "LDAPS", "--skip-tls-verify", - "--host", "ldap-bind-server full", + "--host-list", "ldap-bind-server full", "--port", "9876", "--user-search-base", "ou=Users,dc=full-domain-bind,dc=org", "--user-filter", "(memberOf=cn=user-group,ou=example,dc=full-domain-bind,dc=org)", @@ -625,7 +625,7 @@ func TestUpdateLdapBindDn(t *testing.T) { args: []string{ "ldap-test", "--id", "1", - "--host", "ldap-server", + "--host-list", "ldap-server", }, authSource: &auth.Source{ Type: auth.LDAP, @@ -957,7 +957,7 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--not-active", "--security-protocol", "starttls", "--skip-tls-verify", - "--host", "ldap-simple-server full", + "--host-list", "ldap-simple-server full", "--port", "987", "--user-search-base", "ou=Users,dc=full-domain-simple,dc=org", "--user-filter", "(&(objectClass=posixAccount)(full-simple-cn=%s))", @@ -1073,7 +1073,7 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { args: []string{ "ldap-test", "--id", "1", - "--host", "ldap-server", + "--host-list", "ldap-server", }, authSource: &auth.Source{ Type: auth.DLDAP, diff --git a/services/auth/source/ldap/README.md b/services/auth/source/ldap/README.md index 34c811703f65..ec09eee05d99 100644 --- a/services/auth/source/ldap/README.md +++ b/services/auth/source/ldap/README.md @@ -32,8 +32,9 @@ share the following fields: * A name to assign to the new method of authorization. * Host **(required)** - * The address where the LDAP server can be reached. + * The list of addresses where the LDAP server(s) can be reached. * Example: mydomain.com + * Example (with multiple server hosts): mydomain.com, myotherdomain.com, mytempdomain.com * Port **(required)** * The port to use when connecting to the server. diff --git a/services/auth/source/ldap/source_search.go b/services/auth/source/ldap/source_search.go index 9b899e7eb9e0..40016ed271d3 100644 --- a/services/auth/source/ldap/source_search.go +++ b/services/auth/source/ldap/source_search.go @@ -113,8 +113,11 @@ func dial(source *Source) (*ldap.Conn, error) { log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify) ldap.DefaultTimeout = time.Second * 15 + // Remove any extra spaces in HostList string + tempHostList := strings.ReplaceAll(source.HostList, " ", "") // HostList is a list of hosts separated by commas - hostList := strings.Split(source.HostList, ",") + hostList := strings.Split(tempHostList, ",") + // hostList := strings.Split(source.HostList, ",") for _, host := range hostList { tlsConfig := &tls.Config{ From 67235999db1fd834b99e20739e3a3b30a0b140e2 Mon Sep 17 00:00:00 2001 From: abhishek kumar gupta Date: Thu, 18 Jul 2024 09:12:00 +0000 Subject: [PATCH 3/7] run lint (#6898) Signed-off-by: abhishek kumar gupta --- services/auth/source/ldap/source_search.go | 1 - 1 file changed, 1 deletion(-) diff --git a/services/auth/source/ldap/source_search.go b/services/auth/source/ldap/source_search.go index 40016ed271d3..e785f3953ebc 100644 --- a/services/auth/source/ldap/source_search.go +++ b/services/auth/source/ldap/source_search.go @@ -127,7 +127,6 @@ func dial(source *Source) (*ldap.Conn, error) { if source.SecurityProtocol == SecurityProtocolLDAPS { conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) - if err != nil { // Connection failed, try again with the next host. log.Trace("error during Dial for host %s: %w", host, err) From 88ca8fabea6cc4c0ed6bb535547c021711bae25a Mon Sep 17 00:00:00 2001 From: abhishek818 Date: Tue, 23 Jul 2024 03:36:11 +0530 Subject: [PATCH 4/7] close failed connections (#6898) Signed-off-by: abhishek818 --- services/auth/source/ldap/source_search.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/services/auth/source/ldap/source_search.go b/services/auth/source/ldap/source_search.go index e785f3953ebc..ad5ebe365c68 100644 --- a/services/auth/source/ldap/source_search.go +++ b/services/auth/source/ldap/source_search.go @@ -129,6 +129,7 @@ func dial(source *Source) (*ldap.Conn, error) { conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) if err != nil { // Connection failed, try again with the next host. + conn.Close() log.Trace("error during Dial for host %s: %w", host, err) continue } @@ -139,6 +140,7 @@ func dial(source *Source) (*ldap.Conn, error) { conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) if err != nil { + conn.Close() log.Trace("error during Dial for host %s: %w", host, err) continue } From f2c4cae867536f6f2e2f8f3519691f933df2e388 Mon Sep 17 00:00:00 2001 From: abhishek818 Date: Wed, 24 Jul 2024 15:19:55 +0530 Subject: [PATCH 5/7] race the tcp connections (#6898) Signed-off-by: abhishek818 --- services/auth/source/ldap/source_search.go | 78 +++++++++++++--------- 1 file changed, 48 insertions(+), 30 deletions(-) diff --git a/services/auth/source/ldap/source_search.go b/services/auth/source/ldap/source_search.go index ad5ebe365c68..5ad2247e330a 100644 --- a/services/auth/source/ldap/source_search.go +++ b/services/auth/source/ldap/source_search.go @@ -112,50 +112,68 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) { func dial(source *Source) (*ldap.Conn, error) { log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify) - ldap.DefaultTimeout = time.Second * 15 + ldap.DefaultTimeout = time.Second * 10 // Remove any extra spaces in HostList string tempHostList := strings.ReplaceAll(source.HostList, " ", "") // HostList is a list of hosts separated by commas hostList := strings.Split(tempHostList, ",") - // hostList := strings.Split(source.HostList, ",") - for _, host := range hostList { - tlsConfig := &tls.Config{ - ServerName: host, - InsecureSkipVerify: source.SkipVerify, - } + type result struct { + conn *ldap.Conn + err error + } - if source.SecurityProtocol == SecurityProtocolLDAPS { - conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) - if err != nil { - // Connection failed, try again with the next host. - conn.Close() - log.Trace("error during Dial for host %s: %w", host, err) - continue + results := make(chan result, len(hostList)) + + for _, host := range hostList { + go func(host string) { + tlsConfig := &tls.Config{ + ServerName: host, + InsecureSkipVerify: source.SkipVerify, } - conn.SetTimeout(time.Second * 10) - return conn, err - } + var conn *ldap.Conn + var err error - conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) - if err != nil { - conn.Close() - log.Trace("error during Dial for host %s: %w", host, err) - continue - } - conn.SetTimeout(time.Second * 10) + if source.SecurityProtocol == SecurityProtocolLDAPS { + conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) + } else { + conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) + if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS { + err = conn.StartTLS(tlsConfig) + } + } - if source.SecurityProtocol == SecurityProtocolStartTLS { - if err = conn.StartTLS(tlsConfig); err != nil { - conn.Close() - log.Trace("error during StartTLS for host %s: %w", host, err) - continue + if err != nil { + if conn != nil { + conn.Close() + } + log.Trace("error during Dial for host %s: %w", host, err) + results <- result{nil, err} + return } + + conn.SetTimeout(time.Second * 10) + results <- result{conn, nil} + }(host) + } + + for range hostList { + r := <-results + if r.err == nil { + // Close other connections still in progress + go func() { + for range hostList { + r := <-results + if r.conn != nil { + r.conn.Close() + } + } + }() + return r.conn, nil } } - // All servers were unreachable return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList) } From 8d9269ffb93651a73e536e9a072fdf2c024cb116 Mon Sep 17 00:00:00 2001 From: abhishek818 Date: Wed, 21 Aug 2024 01:24:58 +0530 Subject: [PATCH 6/7] fix integration test (#6898) rename host to hostlist in html template Signed-off-by: abhishek818 --- templates/admin/auth/edit.tmpl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/templates/admin/auth/edit.tmpl b/templates/admin/auth/edit.tmpl index 660f0d088154..0e669e60cae2 100644 --- a/templates/admin/auth/edit.tmpl +++ b/templates/admin/auth/edit.tmpl @@ -36,7 +36,7 @@
- +
From 999df82c7536408c967d3a3f02d61cf1187cee2a Mon Sep 17 00:00:00 2001 From: abhishek818 Date: Thu, 26 Sep 2024 18:09:09 +0530 Subject: [PATCH 7/7] refactor code (#6898) Signed-off-by: abhishek818 --- services/auth/source/ldap/source_search.go | 87 +++++++++++++--------- 1 file changed, 51 insertions(+), 36 deletions(-) diff --git a/services/auth/source/ldap/source_search.go b/services/auth/source/ldap/source_search.go index 5ad2247e330a..2b3cc091fd83 100644 --- a/services/auth/source/ldap/source_search.go +++ b/services/auth/source/ldap/source_search.go @@ -10,6 +10,7 @@ import ( "net" "strconv" "strings" + "sync" "time" "code.gitea.io/gitea/modules/container" @@ -32,6 +33,12 @@ type SearchResult struct { Groups container.Set[string] } +// DialResult : dial response +type DialResult struct { + conn *ldap.Conn + err error +} + func (source *Source) sanitizedUserQuery(username string) (string, bool) { // See http://tools.ietf.org/search/rfc4515 badCharacters := "\x00()*\\" @@ -109,6 +116,39 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) { return userDN, true } +func dialHost(host string, source *Source, results chan DialResult, wg *sync.WaitGroup) { + defer wg.Done() + + tlsConfig := &tls.Config{ + ServerName: host, + InsecureSkipVerify: source.SkipVerify, + } + + var conn *ldap.Conn + var err error + + if source.SecurityProtocol == SecurityProtocolLDAPS { + conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) + } else { + conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) + if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS { + err = conn.StartTLS(tlsConfig) + } + } + + if err != nil { + if conn != nil { + conn.Close() + } + log.Trace("error during Dial for host %s: %w", host, err) + results <- DialResult{nil, err} + return + } + + conn.SetTimeout(time.Second * 10) + results <- DialResult{conn, nil} +} + func dial(source *Source) (*ldap.Conn, error) { log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify) @@ -118,46 +158,21 @@ func dial(source *Source) (*ldap.Conn, error) { // HostList is a list of hosts separated by commas hostList := strings.Split(tempHostList, ",") - type result struct { - conn *ldap.Conn - err error - } - - results := make(chan result, len(hostList)) + results := make(chan DialResult, len(hostList)) + var wg sync.WaitGroup + // Race all connections for _, host := range hostList { - go func(host string) { - tlsConfig := &tls.Config{ - ServerName: host, - InsecureSkipVerify: source.SkipVerify, - } - - var conn *ldap.Conn - var err error - - if source.SecurityProtocol == SecurityProtocolLDAPS { - conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) - } else { - conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) - if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS { - err = conn.StartTLS(tlsConfig) - } - } - - if err != nil { - if conn != nil { - conn.Close() - } - log.Trace("error during Dial for host %s: %w", host, err) - results <- result{nil, err} - return - } - - conn.SetTimeout(time.Second * 10) - results <- result{conn, nil} - }(host) + wg.Add(1) + go dialHost(host, source, results, &wg) } + // Close the results channel after all goroutines finish + go func() { + wg.Wait() + close(results) + }() + for range hostList { r := <-results if r.err == nil {