Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(portforwarding): Allow running script upon port forwarding success #2399

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions internal/configuration/settings/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ type PortForwarding struct {
// to write to a file. It cannot be nil for the
// internal state
Filepath *string `json:"status_file_path"`
// Scriptpath is the port forwarding status script path
// to use. It can be the empty string to indicate not
// to call a script. It cannot be nil for the
// internal state
Scriptpath *string `json:"status_script_path"`
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved
// ListeningPort is the port traffic would be redirected to from the
// forwarded port. The redirection is disabled if it is set to 0, which
// is its default as well.
Expand Down Expand Up @@ -66,6 +71,14 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
}
}

// Validate Scriptpath
if *p.Scriptpath != "" { // optional
_, err := filepath.Abs(*p.Scriptpath)
if err != nil {
return fmt.Errorf("scriptpath is not valid: %w", err)
}
}
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved

if providerSelected == providers.PrivateInternetAccess {
switch {
case p.Username == "":
Expand All @@ -83,6 +96,7 @@ func (p *PortForwarding) Copy() (copied PortForwarding) {
Enabled: gosettings.CopyPointer(p.Enabled),
Provider: gosettings.CopyPointer(p.Provider),
Filepath: gosettings.CopyPointer(p.Filepath),
Scriptpath: gosettings.CopyPointer(p.Scriptpath),
ListeningPort: gosettings.CopyPointer(p.ListeningPort),
Username: p.Username,
Password: p.Password,
Expand All @@ -93,6 +107,7 @@ func (p *PortForwarding) OverrideWith(other PortForwarding) {
p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled)
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
p.Scriptpath = gosettings.OverrideWithPointer(p.Scriptpath, other.Scriptpath)
p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort)
p.Username = gosettings.OverrideWithComparable(p.Username, other.Username)
p.Password = gosettings.OverrideWithComparable(p.Password, other.Password)
Expand All @@ -102,6 +117,7 @@ func (p *PortForwarding) setDefaults() {
p.Enabled = gosettings.DefaultPointer(p.Enabled, false)
p.Provider = gosettings.DefaultPointer(p.Provider, "")
p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port")
p.Scriptpath = gosettings.DefaultPointer(p.Scriptpath, "")
p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0)
}

Expand Down Expand Up @@ -134,6 +150,12 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) {
}
node.Appendf("Forwarded port file path: %s", filepath)

scriptpath := *p.Scriptpath
if scriptpath == "" {
scriptpath = "[not set]"
}
node.Appendf("Forwarded port script path: %s", scriptpath)
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved

if p.Username != "" {
credentialsNode := node.Appendf("Credentials:")
credentialsNode.Appendf("Username: %s", p.Username)
Expand Down Expand Up @@ -162,6 +184,9 @@ func (p *PortForwarding) read(r *reader.Reader) (err error) {
"PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING_STATUS_FILE",
))

p.Scriptpath = r.Get("VPN_PORT_FORWARDING_STATUS_SCRIPT",
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved
reader.ForceLowercase(false))
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved

p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT")
if err != nil {
return err
Expand Down
7 changes: 5 additions & 2 deletions internal/portforward/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ type Loop struct {

func NewLoop(settings settings.PortForwarding, routing Routing,
client *http.Client, portAllower PortAllower,
logger Logger, uid, gid int) *Loop {
logger Logger, uid, gid int,
) *Loop {
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved
return &Loop{
settings: Settings{
VPNIsUp: ptrTo(false),
Service: service.Settings{
Enabled: settings.Enabled,
Filepath: *settings.Filepath,
Scriptpath: *settings.Scriptpath,
ListeningPort: *settings.ListeningPort,
},
},
Expand Down Expand Up @@ -75,7 +77,8 @@ func (l *Loop) Start(_ context.Context) (runError <-chan error, _ error) {

func (l *Loop) run(runCtx context.Context, runDone chan<- struct{},
runErrorCh chan<- error, updateTrigger <-chan Settings,
updateResult chan<- error) {
updateResult chan<- error,
) {
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved
defer close(runDone)

var serviceRunError <-chan error
Expand Down
29 changes: 29 additions & 0 deletions internal/portforward/service/script.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package service

import (
"fmt"
"os"
"os/exec"
"strings"
)

func (s *Service) runPortForwardedScript(ports []uint16) (err error) {
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved
// run bash script with ports as arguments
portStrings := make([]string, len(ports))
for i, port := range ports {
portStrings[i] = fmt.Sprint(int(port))
}
portsString := strings.Join(portStrings, " ")
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved

scriptPath := s.settings.Scriptpath
s.logger.Info("running port forward script " + scriptPath)
cmd := exec.Command(scriptPath, portsString)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err = cmd.Run()
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return fmt.Errorf("running script: %w", err)
}

return nil
}
3 changes: 3 additions & 0 deletions internal/portforward/service/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type Settings struct {
Enabled *bool
PortForwarder PortForwarder
Filepath string
Scriptpath string
Interface string // needed for PIA and ProtonVPN, tun0 for example
ServerName string // needed for PIA
CanPortForward bool // needed for PIA
Expand All @@ -24,6 +25,7 @@ func (s Settings) Copy() (copied Settings) {
copied.Enabled = gosettings.CopyPointer(s.Enabled)
copied.PortForwarder = s.PortForwarder
copied.Filepath = s.Filepath
copied.Scriptpath = s.Scriptpath
copied.Interface = s.Interface
copied.ServerName = s.ServerName
copied.CanPortForward = s.CanPortForward
Expand All @@ -37,6 +39,7 @@ func (s *Settings) OverrideWith(update Settings) {
s.Enabled = gosettings.OverrideWithPointer(s.Enabled, update.Enabled)
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
s.Scriptpath = gosettings.OverrideWithComparable(s.Scriptpath, update.Scriptpath)
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
Expand Down
11 changes: 10 additions & 1 deletion internal/portforward/service/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
return nil, fmt.Errorf("writing port file: %w", err)
}

if s.settings.Scriptpath == "" {
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved
err = s.runPortForwardedScript(ports)
if err != nil {
_ = s.cleanup()
return nil, fmt.Errorf("running port forward script: %w", err)
}
}

s.portMutex.Lock()
s.ports = ports
s.portMutex.Unlock()
Expand All @@ -80,7 +88,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
s.keepPortDoneCh = keepPortDoneCh

go func(ctx context.Context, portForwarder PortForwarder,
obj utils.PortForwardObjects, runError chan<- error, doneCh chan<- struct{}) {
obj utils.PortForwardObjects, runError chan<- error, doneCh chan<- struct{},
) {
lavalleeale marked this conversation as resolved.
Show resolved Hide resolved
defer close(doneCh)
err = portForwarder.KeepPortForward(ctx, obj)
crashed := ctx.Err() == nil
Expand Down