diff --git a/services/auth/source/ldap/source_search.go b/services/auth/source/ldap/source_search.go index ad5ebe365c68..5ad2247e330a 100644 --- a/services/auth/source/ldap/source_search.go +++ b/services/auth/source/ldap/source_search.go @@ -112,50 +112,68 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) { func dial(source *Source) (*ldap.Conn, error) { log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify) - ldap.DefaultTimeout = time.Second * 15 + ldap.DefaultTimeout = time.Second * 10 // Remove any extra spaces in HostList string tempHostList := strings.ReplaceAll(source.HostList, " ", "") // HostList is a list of hosts separated by commas hostList := strings.Split(tempHostList, ",") - // hostList := strings.Split(source.HostList, ",") - for _, host := range hostList { - tlsConfig := &tls.Config{ - ServerName: host, - InsecureSkipVerify: source.SkipVerify, - } + type result struct { + conn *ldap.Conn + err error + } - if source.SecurityProtocol == SecurityProtocolLDAPS { - conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) - if err != nil { - // Connection failed, try again with the next host. - conn.Close() - log.Trace("error during Dial for host %s: %w", host, err) - continue + results := make(chan result, len(hostList)) + + for _, host := range hostList { + go func(host string) { + tlsConfig := &tls.Config{ + ServerName: host, + InsecureSkipVerify: source.SkipVerify, } - conn.SetTimeout(time.Second * 10) - return conn, err - } + var conn *ldap.Conn + var err error - conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) - if err != nil { - conn.Close() - log.Trace("error during Dial for host %s: %w", host, err) - continue - } - conn.SetTimeout(time.Second * 10) + if source.SecurityProtocol == SecurityProtocolLDAPS { + conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) + } else { + conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) + if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS { + err = conn.StartTLS(tlsConfig) + } + } - if source.SecurityProtocol == SecurityProtocolStartTLS { - if err = conn.StartTLS(tlsConfig); err != nil { - conn.Close() - log.Trace("error during StartTLS for host %s: %w", host, err) - continue + if err != nil { + if conn != nil { + conn.Close() + } + log.Trace("error during Dial for host %s: %w", host, err) + results <- result{nil, err} + return } + + conn.SetTimeout(time.Second * 10) + results <- result{conn, nil} + }(host) + } + + for range hostList { + r := <-results + if r.err == nil { + // Close other connections still in progress + go func() { + for range hostList { + r := <-results + if r.conn != nil { + r.conn.Close() + } + } + }() + return r.conn, nil } } - // All servers were unreachable return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList) }