Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implemented ability to dynamically reload TLS certificates #866

Draft
wants to merge 1 commit into
base: v2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,15 @@ func (c *Client) SetRootCertificate(pemFilePath string) *Client {
return c
}

// SetRootCertificateWatcher enables dynamic reloading of one or more root certificates.
// It is designed for scenarios involving long-running Resty clients where certificates may be renewed.
//
// client.SetRootCertificateWatcher(&WatcherOptions{PemFilePath: "root-ca.crt"})
func (c *Client) SetRootCertificateWatcher(options *WatcherOptions) *Client {
c.handleCAsWatcher("root", options)
return c
}

// SetRootCertificateFromString method helps to add one or more root certificates
// into the Resty client
//
Expand All @@ -888,6 +897,15 @@ func (c *Client) SetClientRootCertificate(pemFilePath string) *Client {
return c
}

// SetClientRootCertificateWatcher enables dynamic reloading of one or more root certificates.
// It is designed for scenarios involving long-running Resty clients where certificates may be renewed.
//
// client.SetClientRootCertificateWatcher(&WatcherOptions{PemFilePath: "root-ca.crt"})
func (c *Client) SetClientRootCertificateWatcher(options *WatcherOptions) *Client {
c.handleCAsWatcher("client", options)
return c
}

// SetClientRootCertificateFromString method helps to add one or more clients
// root certificates into the Resty client
//
Expand Down Expand Up @@ -918,6 +936,36 @@ func (c *Client) handleCAs(scope string, permCerts []byte) {
}
}

func (c *Client) handleCAsWatcher(scope string, options *WatcherOptions) {
pw, err := newPemWatcher(options, c.log, c.Debug)
if err != nil {
c.log.Errorf("%v", err)
return
}

tlsConfig, err := c.tlsConfig()
if err != nil {
c.log.Errorf("%v", err)
return
}

c.OnBeforeRequest(func(client *Client, request *Request) error {
certPool, err := pw.CertPool()
if err != nil {
return err
}

switch scope {
case "root":
tlsConfig.RootCAs = certPool
case "client":
tlsConfig.ClientCAs = certPool
}

return nil
})
}

// SetOutputDirectory method sets the output directory for saving HTTP responses in a file.
// Resty creates one if the output directory does not exist. This setting is optional,
// if you plan to use the absolute path in [Request.SetOutput] and can used together.
Expand Down
124 changes: 124 additions & 0 deletions pem_watcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (c) 2015-2024 Jeevanandam M ([email protected]), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.

package resty

import (
"crypto/x509"
"errors"
"os"
"time"
)

const (
defaultWatcherPoolingInterval = 1 * time.Minute
)

// WatcherOptions struct is used to enable TLS Certificate hot reloading.
type WatcherOptions struct {
// PemFilePath is the path of the PEM file
PemFilePath string

// PoolingInterval is the frequency at which resty will check if the PEM file needs to be reloaded.
// Default is 1 min.
PoolingInterval time.Duration
}

type pemWatcher struct {
opt *WatcherOptions

certPool *x509.CertPool
modTime time.Time
lastChecked time.Time
log Logger
debug bool
}

func newPemWatcher(options *WatcherOptions, log Logger, debug bool) (*pemWatcher, error) {
if options.PemFilePath == "" {
return nil, errors.New("PemFilePath is required")
}

if options.PoolingInterval == 0 {
options.PoolingInterval = defaultWatcherPoolingInterval
}

cw := &pemWatcher{
opt: options,
log: log,
debug: debug,
}

if err := cw.checkRefresh(); err != nil {
return nil, err
}

return cw, nil
}

func (pw *pemWatcher) CertPool() (*x509.CertPool, error) {
if err := pw.checkRefresh(); err != nil {
return nil, err
}

return pw.certPool, nil
}

func (pw *pemWatcher) checkRefresh() error {
if time.Since(pw.lastChecked) <= pw.opt.PoolingInterval {
return nil
}

pw.Debugf("Checking if cert has changed...")

newModTime, err := pw.getModTime()
if err != nil {
return err
}

if pw.modTime.Equal(newModTime) {
pw.lastChecked = time.Now().UTC()
pw.Debugf("No change")
return nil
}

if err := pw.refreshCertPool(); err != nil {
return err
}

pw.modTime = newModTime
pw.lastChecked = time.Now().UTC()

pw.Debugf("Cert refreshed")

return nil
}

func (pw *pemWatcher) getModTime() (time.Time, error) {
info, err := os.Stat(pw.opt.PemFilePath)
if err != nil {
return time.Time{}, err
}

return info.ModTime().UTC(), nil
}

func (pw *pemWatcher) refreshCertPool() error {
pemCert, err := os.ReadFile(pw.opt.PemFilePath)
if err != nil {
return nil
}

pw.certPool = x509.NewCertPool()
pw.certPool.AppendCertsFromPEM(pemCert)
return nil
}

func (pw *pemWatcher) Debugf(format string, v ...interface{}) {
if !pw.debug {
return
}

pw.log.Debugf(format, v...)
}
Loading