Skip to content

Commit

Permalink
allow different ssh ports
Browse files Browse the repository at this point in the history
  • Loading branch information
Deleh committed Mar 13, 2024
1 parent f078982 commit 85cd437
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions robmuxinator/robmuxinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,15 @@ def format(self, record):
class SSHClient:
"""Handle commands over ssh tunnel"""

def __init__(self, user, hostname):
def __init__(self, user, hostname, port=DEFAULT_PORT):
self._user = user
self._hostname = hostname

if port is not None:
self._port = port
else:
self._port = 22

# check if user has sudo privileges
self._sudo_user = True if os.getuid() == 0 else False

Expand All @@ -151,6 +156,7 @@ def init_connection(self):
self.ssh_cli.connect(
username=self._user,
hostname=self._hostname,
port=self._port,
key_filename=key_filename,
disabled_algorithms={"pubkeys": ["rsa-sha2-256", "rsa-sha2-512"]},
)
Expand Down Expand Up @@ -259,6 +265,9 @@ def __init__(self, hostname, user, port=DEFAULT_PORT):
def get_hostname(self):
return self._hostname

def get_port(self):
return self._port

def shutdown(self, timeout=30):
pass

Expand Down Expand Up @@ -312,7 +321,7 @@ class LinuxHost(Host):

def __init__(self, hostname, user, port=DEFAULT_PORT, check_nfs=True):
super().__init__(hostname, user, port)
self._ssh_client = SSHClient(user, hostname)
self._ssh_client = SSHClient(user, hostname, port)
self._check_nfs = check_nfs

def shutdown(self, timeout=60):
Expand Down Expand Up @@ -406,6 +415,7 @@ def shutdown(self, timeout=60):
logger.info(" {} is down".format(self._hostname))
return True


class OnlineHost(Host):
"""Handle hosts which should only be available on the network"""

Expand All @@ -414,6 +424,8 @@ def __init__(self, hostname, port=DEFAULT_PORT):

def shutdown(self, timeout=60):
return True


class Session(object):
def __init__(self, ssh_client, session_name, yaml, envs=None) -> None:
self._session_name = session_name
Expand Down Expand Up @@ -531,6 +543,7 @@ def dump(self):
print("\tprio: {}".format(self.prio))
print("\tlocked: {}".format(self._locked))


def wait_for_hosts(hosts, timeout=30):
start = datetime.now()
logger.info("==================================")
Expand Down Expand Up @@ -844,7 +857,7 @@ def main():
if key in args.sessions:
sessions.append(
Session(
SSHClient(user=user, hostname=hosts[host].get_hostname()),
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_port()),
key,
yaml_sessions[key],
envs
Expand All @@ -853,7 +866,7 @@ def main():
else:
sessions.append(
Session(
SSHClient(user=user, hostname=hosts[host].get_hostname()),
SSHClient(user=user, hostname=hosts[host].get_hostname(), port=hosts[host].get_port()),
key,
yaml_sessions[key],
envs
Expand Down

0 comments on commit 85cd437

Please sign in to comment.