diff --git a/client/pkg/transport/listener.go b/client/pkg/transport/listener.go index ec69f4e2ad3d..69c23e50d35a 100644 --- a/client/pkg/transport/listener.go +++ b/client/pkg/transport/listener.go @@ -264,9 +264,10 @@ func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertVali NotBefore: time.Now(), NotAfter: time.Now().Add(time.Duration(selfSignedCertValidity) * 365 * (24 * time.Hour)), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCRLSign, ExtKeyUsage: append([]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, additionalUsages...), BasicConstraintsValid: true, + IsCA: true, } if info.Logger != nil { diff --git a/client/pkg/transport/listener_test.go b/client/pkg/transport/listener_test.go index 836dca998f8e..5e28f24d3a6e 100644 --- a/client/pkg/transport/listener_test.go +++ b/client/pkg/transport/listener_test.go @@ -15,12 +15,17 @@ package transport import ( + "crypto/rand" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" + "math/big" "net" "net/http" "os" + "path/filepath" "testing" "time" @@ -573,3 +578,153 @@ func TestSocktOptsEmpty(t *testing.T) { } } } + +// TestNewListenerWithACRLFile tests when a revocation list is present. +func TestNewListenerWithACRLFile(t *testing.T) { + clientTLSInfo, err := createSelfCertEx(t, "127.0.0.1", x509.ExtKeyUsageClientAuth) + if err != nil { + t.Fatalf("unable to create client cert: %v", err) + } + loaded, err := os.ReadFile(clientTLSInfo.CertFile) + if err != nil { + t.Fatalf("unable to read client cert: %v", err) + } + block, _ := pem.Decode(loaded) + clientCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("unable to parse client cert: %v", err) + } + + tests := map[string]struct { + overrideCRLFileName bool + expectHanshakeError bool + crlFile string + revokedCertificates []pkix.RevokedCertificate + revocationListContents []byte + }{ + "empty revocation list": { + expectHanshakeError: false, + }, + "invalid CRL path": { + overrideCRLFileName: true, + crlFile: "@badname", + expectHanshakeError: true, + }, + "client cert is revoked": { + expectHanshakeError: true, + revokedCertificates: []pkix.RevokedCertificate{ + { + SerialNumber: clientCert.SerialNumber, + RevocationTime: time.Now(), + }, + }, + }, + "invalid CRL file content": { + expectHanshakeError: true, + revocationListContents: []byte("@invalidcontent"), + }, + } + + for testName, test := range tests { + t.Run(testName, func(t *testing.T) { + tmpdir := t.TempDir() + tlsInfo, err := createSelfCert(t) + if err != nil { + t.Fatalf("unable to create server cert: %v", err) + } + tlsInfo.TrustedCAFile = clientTLSInfo.CertFile + + crlFile := filepath.Join(tmpdir, "revoked.r0") + if test.overrideCRLFileName { + tlsInfo.CRLFile = test.crlFile + } else { + tlsInfo.CRLFile = crlFile + } + + loaded, err := os.ReadFile(tlsInfo.CertFile) + if err != nil { + t.Fatalf("unable to read server cert: %v", err) + } + block, _ := pem.Decode(loaded) + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("unable to decode server cert: %v", err) + } + + loaded, err = os.ReadFile(tlsInfo.KeyFile) + if err != nil { + t.Fatalf("unable to read server key: %v", err) + } + block, _ = pem.Decode(loaded) + key, err := x509.ParseECPrivateKey(block.Bytes) + if err != nil { + t.Fatalf("unable to parse server key: %v", err) + } + + revocationListContents := test.revocationListContents + if len(revocationListContents) == 0 { + tmpl := &x509.RevocationList{ + RevokedCertificates: test.revokedCertificates, + ThisUpdate: time.Now(), + NextUpdate: time.Now().Add(time.Hour), + Number: big.NewInt(1), + } + revocationListContents, err = x509.CreateRevocationList(rand.Reader, tmpl, cert, key) + if err != nil { + t.Fatalf("unable to create revocation list: %v", err) + } + } + + if err := os.WriteFile(crlFile, revocationListContents, 0600); err != nil { + t.Fatalf("unable to write revocation list: %v", err) + } + + chHandshakeFailure := make(chan error, 1) + tlsInfo.HandshakeFailure = func(_ *tls.Conn, err error) { + if err != nil { + chHandshakeFailure <- err + } + } + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(cert) + + clientCert, err := tls.LoadX509KeyPair(clientTLSInfo.CertFile, clientTLSInfo.KeyFile) + if err != nil { + t.Fatalf("unable to create peer cert: %v", err) + } + + ln, err := NewListener("127.0.0.1:0", "https", tlsInfo) + defer ln.Close() + if err != nil { + t.Fatalf("unable to start listener: %v", err) + } + + tlsConfig := &tls.Config{} + tlsConfig.InsecureSkipVerify = false + tlsConfig.Certificates = []tls.Certificate{clientCert} + tlsConfig.RootCAs = rootCAs + + tr := &http.Transport{TLSClientConfig: tlsConfig} + cli := &http.Client{Transport: tr} + go cli.Get("https://" + ln.Addr().String()) + + chAcceptConn := make(chan net.Conn, 1) + go func() { + conn, _ := ln.Accept() + chAcceptConn <- conn + }() + + select { + case err := <-chHandshakeFailure: + if !test.expectHanshakeError { + t.Errorf("expecting no handshake error, got: %v", err) + } + case <-chAcceptConn: + if test.expectHanshakeError { + t.Errorf("expecting hanshake error, got nothing") + } + } + }) + } +}