diff --git a/cloudprober.go b/cloudprober.go index ac19ec77..bba2464d 100644 --- a/cloudprober.go +++ b/cloudprober.go @@ -27,7 +27,6 @@ import ( "fmt" "net" "net/http" - "net/url" "os" "strconv" "strings" @@ -83,37 +82,37 @@ func getServerHost(c *configpb.ProberConfig) string { return serverHost } -func parsePort(portStr string) (int64, error) { +func getDefaultServerPort(c *configpb.ProberConfig, l *logger.Logger) (int, error) { + if c.GetPort() != 0 { + return int(c.GetPort()), nil + } + + // If ServerPortEnvVar is defined, it will override the default + // server port. + portStr := os.Getenv(ServerPortEnvVar) + if portStr == "" { + return DefaultServerPort, nil + } + if strings.HasPrefix(portStr, "tcp://") { - u, err := url.Parse(portStr) - if err != nil { - return 0, err - } - if u.Port() == "" { - return 0, fmt.Errorf("no port specified in URL %s", portStr) - } - // u.Port() returns port as a string, thus it - // will be converted to int64 at the end. - portStr = u.Port() + l.Warningf("%s environment variable likely set by Kubernetes (to %s), ignoring it", ServerPortEnvVar, portStr) + return DefaultServerPort, nil } - return strconv.ParseInt(portStr, 10, 32) + + port, err := strconv.ParseInt(portStr, 10, 32) + if err != nil { + return 0, fmt.Errorf("failed to parse default port from the env var: %s=%s", ServerPortEnvVar, portStr) + } + + return int(port), nil } -func initDefaultServer(c *configpb.ProberConfig) (net.Listener, error) { +func initDefaultServer(c *configpb.ProberConfig, l *logger.Logger) (net.Listener, error) { serverHost := getServerHost(c) - serverPort := int(c.GetPort()) - if serverPort == 0 { - serverPort = DefaultServerPort - - // If ServerPortEnvVar is defined, it will override the default - // server port. - if portStr := os.Getenv(ServerPortEnvVar); portStr != "" { - port, err := parsePort(portStr) - if err != nil { - return nil, fmt.Errorf("failed to parse default port from the env var: %s=%s", ServerPortEnvVar, portStr) - } - serverPort = int(port) - } + serverPort, err := getDefaultServerPort(c, l) + + if err != nil { + return nil, err } ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", serverHost, serverPort)) @@ -161,7 +160,7 @@ func InitFromConfig(configFile string) error { // Start default HTTP server. It's used for profile handlers and // prometheus exporter. - ln, err := initDefaultServer(cfg) + ln, err := initDefaultServer(cfg, l) if err != nil { return err } diff --git a/cloudprober_test.go b/cloudprober_test.go index ee0a2a10..ec8577dc 100644 --- a/cloudprober_test.go +++ b/cloudprober_test.go @@ -1,4 +1,4 @@ -// Copyright 2019 Google Inc. +// Copyright 2019-2020 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,33 +15,72 @@ package cloudprober import ( - "strings" + "os" "testing" + + configpb "github.com/google/cloudprober/config/proto" + "google.golang.org/protobuf/proto" ) -func TestParsePort(t *testing.T) { - // test if it parses just a number in string format - port, _ := parsePort("1234") - expectedPort := int64(1234) - if port != expectedPort { - t.Errorf("parsePort(\"%d\") = %d; want %d", expectedPort, port, expectedPort) +func TestGetDefaultServerPort(t *testing.T) { + tests := []struct { + desc string + configPort int32 + envVar string + wantPort int + wantErr bool + }{ + { + desc: "use port from config", + configPort: 9316, + envVar: "3141", + wantPort: 9316, + }, + { + desc: "use default port", + configPort: 0, + envVar: "", + wantPort: DefaultServerPort, + }, + { + desc: "use port from env", + configPort: 0, + envVar: "3141", + wantPort: 3141, + }, + { + desc: "ignore kubernetes port", + configPort: 0, + envVar: "tcp://100.101.102.103:3141", + wantPort: 9313, + }, + { + desc: "error due to bad env var", + configPort: 0, + envVar: "a3141", + wantErr: true, + }, } - // test if it parses full URL - testStr := "tcp://10.1.1.4:9313" - port, _ = parsePort(testStr) - expectedPort = int64(9313) - if port != expectedPort { - t.Errorf("parsePort(\"%s\") = %d; want %d", testStr, port, expectedPort) - } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + os.Setenv(ServerPortEnvVar, test.envVar) + port, err := getDefaultServerPort(&configpb.ProberConfig{ + Port: proto.Int32(test.configPort), + }, nil) - // test if it detects absent port in URL - testStr = "tcp://10.1.1.4" - _, err := parsePort(testStr) - errStr := "no port specified in URL" - if err != nil && !strings.Contains(err.Error(), errStr) { - t.Errorf("parsePort(\"%s\") doesn't return \"%s\" error, however found error: %s", testStr, "no port specified in URL", err.Error()) - } else if err == nil { - t.Errorf("parsePort(\"%s\") should return \"%s\" error, however no errors found", testStr, "no port specified in URL") + if err != nil { + if !test.wantErr { + t.Errorf("Got unexpected error: %v", err) + } else { + return + } + } + + if port != test.wantPort { + t.Errorf("got port: %d, want port: %d", port, test.wantPort) + } + }) } + }