From 16730625f73bdf95190d23548f4379856583b506 Mon Sep 17 00:00:00 2001 From: Clement Doucy Date: Sun, 22 Sep 2024 20:06:46 +0200 Subject: [PATCH] feat: Implemented ability to dynamically reload TLS certificates --- client.go | 48 +++++++++ pem_watcher.go | 124 ++++++++++++++++++++++ pem_watcher_test.go | 250 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 422 insertions(+) create mode 100644 pem_watcher.go create mode 100644 pem_watcher_test.go diff --git a/client.go b/client.go index 87a88da..ff45382 100644 --- a/client.go +++ b/client.go @@ -865,6 +865,15 @@ func (c *Client) SetRootCertificate(pemFilePath string) *Client { return c } +// SetRootCertificateWatcher enables dynamic reloading of one or more root certificates. +// It is designed for scenarios involving long-running Resty clients where certificates may be renewed. +// +// client.SetRootCertificateWatcher(&WatcherOptions{PemFilePath: "root-ca.crt"}) +func (c *Client) SetRootCertificateWatcher(options *WatcherOptions) *Client { + c.handleCAsWatcher("root", options) + return c +} + // SetRootCertificateFromString method helps to add one or more root certificates // into the Resty client // @@ -888,6 +897,15 @@ func (c *Client) SetClientRootCertificate(pemFilePath string) *Client { return c } +// SetClientRootCertificateWatcher enables dynamic reloading of one or more root certificates. +// It is designed for scenarios involving long-running Resty clients where certificates may be renewed. +// +// client.SetClientRootCertificateWatcher(&WatcherOptions{PemFilePath: "root-ca.crt"}) +func (c *Client) SetClientRootCertificateWatcher(options *WatcherOptions) *Client { + c.handleCAsWatcher("client", options) + return c +} + // SetClientRootCertificateFromString method helps to add one or more clients // root certificates into the Resty client // @@ -918,6 +936,36 @@ func (c *Client) handleCAs(scope string, permCerts []byte) { } } +func (c *Client) handleCAsWatcher(scope string, options *WatcherOptions) { + pw, err := newPemWatcher(options, c.log, c.Debug) + if err != nil { + c.log.Errorf("%v", err) + return + } + + tlsConfig, err := c.tlsConfig() + if err != nil { + c.log.Errorf("%v", err) + return + } + + c.OnBeforeRequest(func(client *Client, request *Request) error { + certPool, err := pw.CertPool() + if err != nil { + return err + } + + switch scope { + case "root": + tlsConfig.RootCAs = certPool + case "client": + tlsConfig.ClientCAs = certPool + } + + return nil + }) +} + // SetOutputDirectory method sets the output directory for saving HTTP responses in a file. // Resty creates one if the output directory does not exist. This setting is optional, // if you plan to use the absolute path in [Request.SetOutput] and can used together. diff --git a/pem_watcher.go b/pem_watcher.go new file mode 100644 index 0000000..a42dee2 --- /dev/null +++ b/pem_watcher.go @@ -0,0 +1,124 @@ +// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// resty source code and usage is governed by a MIT style +// license that can be found in the LICENSE file. + +package resty + +import ( + "crypto/x509" + "errors" + "os" + "time" +) + +const ( + defaultWatcherPoolingInterval = 1 * time.Minute +) + +// WatcherOptions struct is used to enable TLS Certificate hot reloading. +type WatcherOptions struct { + // PemFilePath is the path of the PEM file + PemFilePath string + + // PoolingInterval is the frequency at which resty will check if the PEM file needs to be reloaded. + // Default is 1 min. + PoolingInterval time.Duration +} + +type pemWatcher struct { + opt *WatcherOptions + + certPool *x509.CertPool + modTime time.Time + lastChecked time.Time + log Logger + debug bool +} + +func newPemWatcher(options *WatcherOptions, log Logger, debug bool) (*pemWatcher, error) { + if options.PemFilePath == "" { + return nil, errors.New("PemFilePath is required") + } + + if options.PoolingInterval == 0 { + options.PoolingInterval = defaultWatcherPoolingInterval + } + + cw := &pemWatcher{ + opt: options, + log: log, + debug: debug, + } + + if err := cw.checkRefresh(); err != nil { + return nil, err + } + + return cw, nil +} + +func (pw *pemWatcher) CertPool() (*x509.CertPool, error) { + if err := pw.checkRefresh(); err != nil { + return nil, err + } + + return pw.certPool, nil +} + +func (pw *pemWatcher) checkRefresh() error { + if time.Since(pw.lastChecked) <= pw.opt.PoolingInterval { + return nil + } + + pw.Debugf("Checking if cert has changed...") + + newModTime, err := pw.getModTime() + if err != nil { + return err + } + + if pw.modTime.Equal(newModTime) { + pw.lastChecked = time.Now().UTC() + pw.Debugf("No change") + return nil + } + + if err := pw.refreshCertPool(); err != nil { + return err + } + + pw.modTime = newModTime + pw.lastChecked = time.Now().UTC() + + pw.Debugf("Cert refreshed") + + return nil +} + +func (pw *pemWatcher) getModTime() (time.Time, error) { + info, err := os.Stat(pw.opt.PemFilePath) + if err != nil { + return time.Time{}, err + } + + return info.ModTime().UTC(), nil +} + +func (pw *pemWatcher) refreshCertPool() error { + pemCert, err := os.ReadFile(pw.opt.PemFilePath) + if err != nil { + return nil + } + + pw.certPool = x509.NewCertPool() + pw.certPool.AppendCertsFromPEM(pemCert) + return nil +} + +func (pw *pemWatcher) Debugf(format string, v ...interface{}) { + if !pw.debug { + return + } + + pw.log.Debugf(format, v...) +} diff --git a/pem_watcher_test.go b/pem_watcher_test.go new file mode 100644 index 0000000..b1c2343 --- /dev/null +++ b/pem_watcher_test.go @@ -0,0 +1,250 @@ +// Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// resty source code and usage is governed by a MIT style +// license that can be found in the LICENSE file. + +package resty + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" +) + +type certPaths struct { + RootCAKey string + RootCACert string + TLSKey string + TLSCert string +} + +func TestClient_SetRootCertificateWatcher(t *testing.T) { + // For this test, we want to: + // - Generate root CA + // - Generate TLS cert signed with root CA + // - Start an HTTPS server + // - Create a resty client with SetRootCertificateWatcher + // - Send multiple request and re-generate the certs periodically to reproduce renewal + + certDir := t.TempDir() + paths := certPaths{ + RootCAKey: filepath.Join(certDir, "root-ca.key"), + RootCACert: filepath.Join(certDir, "root-ca.crt"), + TLSKey: filepath.Join(certDir, "tls.key"), + TLSCert: filepath.Join(certDir, "tls.crt"), + } + + port := findAvailablePort(t) + generateCerts(t, paths) + startHTTPSServer(fmt.Sprintf(":%d", port), paths) + + client := New().SetDebug(true).SetRootCertificateWatcher(&WatcherOptions{ + PemFilePath: paths.RootCACert, + PoolingInterval: time.Second * 1, + }) + + url := fmt.Sprintf("https://localhost:%d/", port) + + for i := 0; i < 5; i++ { + time.Sleep(1 * time.Second) + res, err := client.R().Get(url) + if err != nil { + t.Fatal(err) + } + + assertEqual(t, res.StatusCode(), http.StatusOK) + + if i%2 == 1 { + // Re-generate certs to simulate renewal scenario + generateCerts(t, paths) + } + } +} + +func startHTTPSServer(addr string, path certPaths) { + tlsConfig := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(path.TLSCert, path.TLSKey) + if err != nil { + return nil, err + } + return &cert, nil + }, + } + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + srv := &http.Server{ + Addr: addr, + TLSConfig: tlsConfig, + } + + go func() { + err := srv.ListenAndServeTLS("", "") + if err != nil { + panic(err) + } + }() +} + +func findAvailablePort(t *testing.T) int { + port := -1 + + for port == -1 { + listener, err := net.Listen("tcp", ":0") + if err != nil { + continue + } + port = listener.Addr().(*net.TCPAddr).Port + if err := listener.Close(); err != nil { + t.Fatal(err) + } + } + + return port +} + +func generateCerts(t *testing.T, paths certPaths) { + rootKey, rootCert, err := generateRootCA(paths.RootCAKey, paths.RootCACert) + if err != nil { + t.Fatal(err) + } + + if err := generateTLSCert(paths.TLSKey, paths.TLSCert, rootKey, rootCert); err != nil { + t.Fatal(err) + } +} + +// Generate a Root Certificate Authority (CA) +func generateRootCA(keyPath, certPath string) (*rsa.PrivateKey, []byte, error) { + // Generate the key for the Root CA + rootKey, err := generateKey() + if err != nil { + return nil, nil, err + } + + // Create the root certificate template + rootCertTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"YourOrg"}, + Country: []string{"US"}, + Province: []string{"State"}, + Locality: []string{"City"}, + CommonName: "YourRootCA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), // 10 years validity + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + IsCA: true, + BasicConstraintsValid: true, + } + + // Self-sign the root certificate + rootCert, err := x509.CreateCertificate(rand.Reader, rootCertTemplate, rootCertTemplate, &rootKey.PublicKey, rootKey) + if err != nil { + return nil, nil, err + } + + // Save the Root CA key and certificate + if err := savePEMKey(keyPath, rootKey); err != nil { + return nil, nil, err + } + if err := savePEMCert(certPath, rootCert); err != nil { + return nil, nil, err + } + + return rootKey, rootCert, nil +} + +// Generate a TLS Certificate signed by the Root CA +func generateTLSCert(keyPath, certPath string, rootKey *rsa.PrivateKey, rootCert []byte) error { + // Generate a key for the server + serverKey, err := generateKey() + if err != nil { + return err + } + + // Parse the Root CA certificate + parsedRootCert, err := x509.ParseCertificate(rootCert) + if err != nil { + return err + } + + // Create the server certificate template + serverCertTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"YourOrg"}, + CommonName: "localhost", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), // 1 year validity + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"localhost"}, + } + + // Sign the server certificate with the Root CA + serverCert, err := x509.CreateCertificate(rand.Reader, serverCertTemplate, parsedRootCert, &serverKey.PublicKey, rootKey) + if err != nil { + return err + } + + // Save the server key and certificate + if err := savePEMKey(keyPath, serverKey); err != nil { + return err + } + if err := savePEMCert(certPath, serverCert); err != nil { + return err + } + + return nil +} + +func generateKey() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, 2048) +} + +func savePEMKey(fileName string, key *rsa.PrivateKey) error { + keyFile, err := os.Create(fileName) + if err != nil { + return err + } + defer keyFile.Close() + + privateKeyPEM := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + } + + return pem.Encode(keyFile, privateKeyPEM) +} + +func savePEMCert(fileName string, cert []byte) error { + certFile, err := os.Create(fileName) + if err != nil { + return err + } + defer certFile.Close() + + certPEM := &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert, + } + + return pem.Encode(certFile, certPEM) +}