diff --git a/pysqa/ext/remote.py b/pysqa/ext/remote.py index ea94619..d36f4f1 100644 --- a/pysqa/ext/remote.py +++ b/pysqa/ext/remote.py @@ -1,6 +1,7 @@ # coding: utf-8 # Copyright (c) Jan Janssen +import getpass import json import os import warnings @@ -25,12 +26,18 @@ def __init__(self, config, directory="~/.queues", execute_command=execute_comman ) if "ssh_key" in config.keys(): self._ssh_key = os.path.abspath(os.path.expanduser(config["ssh_key"])) + self._ssh_ask_for_password = False else: self._ssh_key = None if "ssh_password" in config.keys(): self._ssh_password = config["ssh_password"] + self._ssh_ask_for_password = False else: self._ssh_password = None + if "ssh_ask_for_password" in config.keys(): + self._ssh_ask_for_password = config["ssh_ask_for_password"] + else: + self._ssh_ask_for_password = False if "ssh_key_passphrase" in config.keys(): self._ssh_key_passphrase = config["ssh_key_passphrase"] else: @@ -221,7 +228,11 @@ def _transfer_files(self, file_dict, sftp=None, transfer_back=False): def _open_ssh_connection(self): ssh = paramiko.SSHClient() ssh.load_host_keys(self._ssh_known_hosts) - if self._ssh_key is not None and self._ssh_key_passphrase is not None: + if ( + self._ssh_key is not None + and self._ssh_key_passphrase is not None + and not self._ssh_ask_for_password + ): ssh.connect( hostname=self._ssh_host, port=self._ssh_port, @@ -229,7 +240,7 @@ def _open_ssh_connection(self): key_filename=self._ssh_key, passphrase=self._ssh_key_passphrase, ) - elif self._ssh_key is not None: + elif self._ssh_key is not None and not self._ssh_ask_for_password: ssh.connect( hostname=self._ssh_host, port=self._ssh_port, @@ -240,6 +251,7 @@ def _open_ssh_connection(self): self._ssh_password is not None and self._ssh_authenticator_service is None and not self._ssh_two_factor_authentication + and not self._ssh_ask_for_password ): ssh.connect( hostname=self._ssh_host, @@ -247,6 +259,13 @@ def _open_ssh_connection(self): username=self._ssh_username, password=self._ssh_password, ) + elif self._ssh_ask_for_password and not self._ssh_two_factor_authentication: + ssh.connect( + hostname=self._ssh_host, + port=self._ssh_port, + username=self._ssh_username, + password=getpass.getpass(prompt="SSH Password: ", stream=None), + ) elif ( self._ssh_password is not None and self._ssh_authenticator_service is not None @@ -287,6 +306,16 @@ def authentication(title, instructions, prompt_list): ssh._transport.auth_interactive_dumb( username=self._ssh_username, handler=None, submethods="" ) + elif self._ssh_ask_for_password and self._ssh_two_factor_authentication: + ssh.connect( + hostname=self._ssh_host, + port=self._ssh_port, + username=self._ssh_username, + password=getpass.getpass(prompt="SSH Password: ", stream=None), + ) + ssh._transport.auth_interactive_dumb( + username=self._ssh_username, handler=None, submethods="" + ) else: raise ValueError("Un-supported authentication method.")