Skip to content

Commit

Permalink
race the tcp connections (#6898)
Browse files Browse the repository at this point in the history
Signed-off-by: abhishek818 <[email protected]>
  • Loading branch information
abhishek818 committed Jul 24, 2024
1 parent 88ca8fa commit f2c4cae
Showing 1 changed file with 48 additions and 30 deletions.
78 changes: 48 additions & 30 deletions services/auth/source/ldap/source_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,50 +112,68 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) {
func dial(source *Source) (*ldap.Conn, error) {
log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify)

ldap.DefaultTimeout = time.Second * 15
ldap.DefaultTimeout = time.Second * 10
// Remove any extra spaces in HostList string
tempHostList := strings.ReplaceAll(source.HostList, " ", "")
// HostList is a list of hosts separated by commas
hostList := strings.Split(tempHostList, ",")
// hostList := strings.Split(source.HostList, ",")

for _, host := range hostList {
tlsConfig := &tls.Config{
ServerName: host,
InsecureSkipVerify: source.SkipVerify,
}
type result struct {
conn *ldap.Conn
err error
}

if source.SecurityProtocol == SecurityProtocolLDAPS {
conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
if err != nil {
// Connection failed, try again with the next host.
conn.Close()
log.Trace("error during Dial for host %s: %w", host, err)
continue
results := make(chan result, len(hostList))

for _, host := range hostList {
go func(host string) {
tlsConfig := &tls.Config{
ServerName: host,
InsecureSkipVerify: source.SkipVerify,
}
conn.SetTimeout(time.Second * 10)

return conn, err
}
var conn *ldap.Conn
var err error

conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
if err != nil {
conn.Close()
log.Trace("error during Dial for host %s: %w", host, err)
continue
}
conn.SetTimeout(time.Second * 10)
if source.SecurityProtocol == SecurityProtocolLDAPS {
conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
} else {
conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS {
err = conn.StartTLS(tlsConfig)
}
}

if source.SecurityProtocol == SecurityProtocolStartTLS {
if err = conn.StartTLS(tlsConfig); err != nil {
conn.Close()
log.Trace("error during StartTLS for host %s: %w", host, err)
continue
if err != nil {
if conn != nil {
conn.Close()
}
log.Trace("error during Dial for host %s: %w", host, err)
results <- result{nil, err}
return
}

conn.SetTimeout(time.Second * 10)
results <- result{conn, nil}
}(host)
}

for range hostList {
r := <-results
if r.err == nil {
// Close other connections still in progress
go func() {
for range hostList {
r := <-results
if r.conn != nil {
r.conn.Close()
}
}
}()
return r.conn, nil
}
}

// All servers were unreachable
return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList)
}

Expand Down

0 comments on commit f2c4cae

Please sign in to comment.