diff --git a/src/amuse/config.py b/src/amuse/config.py index 51d0f076b1..e60c45b346 100644 --- a/src/amuse/config.py +++ b/src/amuse/config.py @@ -6,27 +6,28 @@ import os import warnings - -def parse_configmk(filename): - configfile = open(filename, "r") - lines = configfile.readlines() - configfile.close() +def parse_configmk_lines(lines,label): cfgvars = dict() if "amuse configuration" not in lines[0]: raise Exception( - "file: {0} is not an amuse configuration file".format(filename) + f"{label} is not an amuse configuration file" ) for line in lines: if "=" in line: var, value = line.split("=", 1) if value.startswith("@") and value.endswith("@"): warnings.warn( - "possible configuration error/ unconfigured variable in" - " {0}".format(filename) + f"possible configuration error/ unconfigured variable in" + f" {label}" ) cfgvars[var] = value.strip() return cfgvars +def parse_configmk(filename): + configfile = open(filename, "r") + lines = configfile.readlines() + configfile.close() + return parse_configmk_lines(lines, "file " + filename) try: configmk = parse_configmk("config.mk") diff --git a/src/amuse/rfi/channel.py b/src/amuse/rfi/channel.py index 64decbe682..09fcebc3b2 100644 --- a/src/amuse/rfi/channel.py +++ b/src/amuse/rfi/channel.py @@ -35,8 +35,9 @@ from amuse.support.options import OptionalAttributes, option, GlobalOptions from amuse.support.core import late from amuse.support import exceptions -from amuse.support import get_amuse_root_dir +from amuse.support import get_amuse_root_dir, get_amuse_package_dir from amuse.rfi import run_command_redirected +from amuse.config import parse_configmk_lines from amuse.rfi import slurm @@ -509,9 +510,10 @@ def XTERM(cls, full_name_of_the_worker, channel, interpreter_executable=None, im @classmethod - def REDIRECT(cls, full_name_of_the_worker, stdoutname, stderrname, command=None, interpreter_executable=None, **options): + def REDIRECT(cls, full_name_of_the_worker, stdoutname, stderrname, command=None, + interpreter_executable=None, run_command_redirected_file=None ): - fname = run_command_redirected.__file__ + fname = run_command_redirected_file or run_command_redirected.__file__ arguments = [fname , stdoutname, stderrname] if not interpreter_executable is None: @@ -806,6 +808,8 @@ def split_message(self, call_id, function_id, call_count, dtype_to_arguments, en self._communicated_splitted_message = True self._merged_results_splitted_message = dtype_to_result + def makedirs(self,directory): + os.makedirs(directory) AbstractMessageChannel.DEBUGGERS = { @@ -1247,17 +1251,11 @@ def __setstate__(self, state): self._is_inuse = False self._communicated_splitted_message = False self.inuse_semaphore = threading.Semaphore() - - - @option(sections=("channel",)) def job_scheduler(self): """Name of the job scheduler to use when starting the code, if given will use job scheduler to find list of hostnames for spawning""" return "" - - - def get_info_from_job_scheduler(self, name, number_of_workers = 1): if name == "slurm": @@ -1277,7 +1275,7 @@ def get_info_from_slurm(cls, number_of_workers): for _ in range(tasks): all_nodes.append(node) cls._scheduler_nodes = all_nodes - cls._scheduler_index = 1 # start at 1 assumes that the python script is running on the first node as the first task + cls._scheduler_index = 1 # start at 1 assumes that the python script is running on the first node as the first task cls._scheduler_initialized = True print("NODES:", cls._scheduler_nodes) hostnames = [] @@ -1291,7 +1289,7 @@ def get_info_from_slurm(cls, number_of_workers): host = ','.join(hostnames) print("HOST:", host, cls._scheduler_index, os.environ['SLURM_TASKS_PER_NODE']) info = MPI.Info.Create() - info['host'] = host #actually in mpich and openmpi, the host parameter is interpreted as a comma separated list of host names, + info['host'] = host # actually in mpich and openmpi, the host parameter is interpreted as a comma separated list of host names, return info @@ -1473,21 +1471,17 @@ def _extra_path_item(self, path_of_the_module): if len(x) > len(result): result = x return result - - - @option(choices=AbstractMessageChannel.DEBUGGERS.keys(), sections=("channel",)) def debugger(self): """Name of the debugger to use when starting the code""" return "none" - - @option(type="boolean") def check_mpi(self): return True + class SocketMessage(AbstractMessage): def _receive_all(self, nbytes, thesocket): @@ -1679,7 +1673,6 @@ def send(self, socket): self.send_doubles(socket, self.encoded_units) # logger.debug("message send") - def send_doubles(self, socket, array): if len(array) > 0: @@ -1717,10 +1710,12 @@ def send_longs(self, socket, array): if len(array) > 0: data_buffer = numpy.array(array, dtype='int64') socket.sendall(data_buffer.tobytes()) - + + class SocketChannel(AbstractMessageChannel): - def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, **options): + def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_executable=None, + remote_env=None, **options): AbstractMessageChannel.__init__(self, **options) #logging.getLogger().setLevel(logging.DEBUG) @@ -1731,10 +1726,16 @@ def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_e self.name_of_the_worker = name_of_the_worker self.interpreter_executable = interpreter_executable - - if self.hostname != None and self.hostname not in ['localhost',socket.gethostname()]: - raise exceptions.CodeException("can only run codes on local machine using SocketChannel, not on %s", self.hostname) - + + if self.hostname == None: + self.hostname="localhost" + + if self.hostname not in ['localhost',socket.gethostname()]: + self.remote=True + self.must_check_if_worker_is_up_to_date=False + else: + self.remote=False + self.id = 0 if not legacy_interface_type is None: @@ -1748,7 +1749,7 @@ def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_e self._communicated_splitted_message = False self.socket = None - + self.remote_env=remote_env @option(sections=("channel",)) def mpiexec(self): @@ -1761,10 +1762,6 @@ def mpiexec(self): def mpiexec_number_of_workers_flag(self): """flag to use, so that the number of workers are defined""" return '-n' - - - - @late def debugger_method(self): @@ -1787,28 +1784,9 @@ def accept_worker_connection(self, server_socket, process): raise exceptions.CodeException('worker still not started after 60 seconds') - - - def start(self): - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - server_socket.bind(('', 0)) - server_socket.settimeout(1.0) - server_socket.listen(1) - - logger.debug("starting socket worker process, listening for worker connection on %s", server_socket.getsockname()) - - #this option set by CodeInterface - logger.debug("mpi_enabled: %s", str(self.initialize_mpi)) - - # set arguments to name of the worker, and port number we listen on - - - self.stdout = None - self.stderr = None - + def generate_command_and_arguments(self,server_address,port): arguments = [] - + if not self.debugger_method is None: command, arguments = self.debugger_method(self.full_name_of_the_worker, self, interpreter_executable=self.interpreter_executable) else: @@ -1835,25 +1813,152 @@ def start(self): command = mpiexec[0] #append with port and hostname where the worker should connect - arguments.append(str(server_socket.getsockname()[1])) + arguments.append(port) #hostname of this machine - arguments.append(str(socket.gethostname())) - + arguments.append(server_address) + #initialize MPI inside worker executable arguments.append('true') else: #append arguments with port and socket where the worker should connect - arguments.append(str(server_socket.getsockname()[1])) + arguments.append(port) #local machine - arguments.append('localhost') + arguments.append(server_address) #do not initialize MPI inside worker executable arguments.append('false') + + return command,arguments + + def remote_env_string(self, hostname): + if self.remote_env is None: + if hostname in self.remote_envs.keys(): + return "source "+self.remote_envs[hostname]+"\n" + else: + return "" + else: + return "source "+self.remote_env +"\n" + + def generate_remote_command_and_arguments(self,hostname, server_address,port): + + # get remote config + args=["ssh","-T", hostname] + + command=self.remote_env_string(self.hostname)+ \ + "amusifier --get-amuse-config" +"\n" + + proc=Popen(args,stdout=PIPE, stdin=PIPE, executable="ssh") + out,err=proc.communicate(command.encode()) + + try: + remote_config=parse_configmk_lines(out.decode().split("\n"),"remote config at "+self.hostname ) + except: + raise Exception(f"failed getting remote config from {self.hostname} - please check remote_env argument ({self.remote_env})") + + # get remote amuse package dir + command=self.remote_env_string(self.hostname)+ \ + "amusifier --get-amuse-package-dir" +"\n" + + proc=Popen(args,stdout=PIPE, stdin=PIPE, executable="ssh") + out,err=proc.communicate(command.encode()) + + remote_package_dir=out.decode().strip(" \n\t") + local_package_dir=get_amuse_package_dir() + + mpiexec=remote_config["MPIEXEC"] + initialize_mpi=remote_config["MPI_ENABLED"] == 'yes' + run_command_redirected_file=run_command_redirected.__file__.replace(local_package_dir,remote_package_dir) + interpreter_executable=None if self.interpreter_executable==None else remote_config["PYTHON"] + # dynamic python workers? (should be send over) + full_name_of_the_worker=self.full_name_of_the_worker.replace(local_package_dir,remote_package_dir) + python_exe_for_redirection=remote_config["PYTHON"] + + if not self.debugger_method is None: + raise Exception("remote socket channel debugging not yet supported") + #command, arguments = self.debugger_method(self.full_name_of_the_worker, self, interpreter_executable=self.interpreter_executable) + else: + if self.redirect_stdout_file == 'none' and self.redirect_stderr_file == 'none': + + if interpreter_executable is None: + command = full_name_of_the_worker + arguments = [] + else: + command = interpreter_executable + arguments = [full_name_of_the_worker] + else: + command, arguments = self.REDIRECT(full_name_of_the_worker, self.redirect_stdout_file, + self.redirect_stderr_file, command=python_exe_for_redirection, + interpreter_executable=interpreter_executable, + run_command_redirected_file=run_command_redirected_file) + + #start arguments with command + arguments.insert(0, command) + + if initialize_mpi and len(mpiexec) > 0: + mpiexec = shlex.split(mpiexec) + # prepend with mpiexec and arguments back to front + arguments.insert(0, str(self.number_of_workers)) + arguments.insert(0, self.mpiexec_number_of_workers_flag) + arguments[:0] = mpiexec + command = mpiexec[0] + + #append with port and hostname where the worker should connect + arguments.append(port) + #hostname of this machine + arguments.append(server_address) - logger.debug("starting process with command `%s`, arguments `%s` and environment '%s'", command, arguments, os.environ) - self.process = Popen(arguments, executable=command, stdin=PIPE, stdout=None, stderr=None, close_fds=self.close_fds) - logger.debug("waiting for connection from worker") + #initialize MPI inside worker executable + arguments.append('true') + else: + #append arguments with port and socket where the worker should connect + arguments.append(port) + #local machine + arguments.append(server_address) + + #do not initialize MPI inside worker executable + arguments.append('false') + + return command,arguments + def start(self): + + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + server_address=self.get_host_ip(self.hostname) + + server_socket.bind((server_address , 0)) + server_socket.settimeout(1.0) + server_socket.listen(1) + + logger.debug("starting socket worker process, listening for worker connection on %s", server_socket.getsockname()) + + #this option set by CodeInterface + logger.debug("mpi_enabled: %s", str(self.initialize_mpi)) + + # set arguments to name of the worker, and port number we listen on + + self.stdout = None + self.stderr = None + + if self.remote: + command,arguments=self.generate_remote_command_and_arguments(self.hostname,server_address,str(server_socket.getsockname()[1])) + else: + command,arguments=self.generate_command_and_arguments(server_address,str(server_socket.getsockname()[1])) + + if self.remote: + logger.debug("starting remote process on %s with command `%s`, arguments `%s` and environment '%s'", self.hostname, command, arguments, os.environ) + ssh_command=self.remote_env_string(self.hostname)+" ".join(arguments) + arguments=["ssh","-T", self.hostname] + command="ssh" + self.process = Popen(arguments, executable=command, stdin=PIPE, stdout=None, stderr=None, close_fds=self.close_fds) + self.process.stdin.write(ssh_command.encode()) + self.process.stdin.close() + else: + logger.debug("starting process with command `%s`, arguments `%s` and environment '%s'", command, arguments, os.environ) + # ~ print(arguments) + self.process = Popen(arguments, executable=command, stdin=PIPE, stdout=None, stderr=None, close_fds=self.close_fds) + + logger.debug("waiting for connection from worker") self.socket, address = self.accept_worker_connection(server_socket, self.process) self.socket.setblocking(1) @@ -1870,7 +1975,12 @@ def start(self): @option(type="boolean", sections=("sockets_channel",)) def close_fds(self): """close open file descriptors when spawning child process""" - return True + return False + + @option(type="dict", sections=("sockets_channel",)) + def remote_envs(self): + """ dict of remote machine - enviroment (source ..) pairs """ + return dict() @option(choices=AbstractMessageChannel.DEBUGGERS.keys(), sections=("channel",)) def debugger(self): @@ -1889,6 +1999,9 @@ def stop(self): self.socket.close() self.socket = None + + if not self.process.stdin is None: + self.process.stdin.close() # should lookinto using poll with a timeout or some other mechanism # when debugger method is on, no killing @@ -1952,8 +2065,6 @@ def send_message(self, call_id, function_id, dtype_to_arguments={}, encoded_unit message.send(self.socket) self._is_inuse = True - - def recv_message(self, call_id, function_id, handle_as_array, has_units=False): @@ -1965,7 +2076,6 @@ def recv_message(self, call_id, function_id, handle_as_array, has_units=False): del self._merged_results_splitted_message return x - message = SocketMessage() message.receive(self.socket) @@ -1988,8 +2098,6 @@ def recv_message(self, call_id, function_id, handle_as_array, has_units=False): return message.to_result(handle_as_array), message.encoded_units else: return message.to_result(handle_as_array) - - def nonblocking_recv_message(self, call_id, function_id, handle_as_array, has_units=False): request = SocketMessage().nonblocking_receive(self.socket) @@ -2031,6 +2139,27 @@ def max_message_length(self): """ return 1000000 + def sanitize_host(self,hostname): + if "@" in hostname: + return hostname.split("@")[1] + return hostname + + def get_host_ip(self, client): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect((self.sanitize_host(client), 80)) + ip=s.getsockname()[0] + s.close() + return ip + + def makedirs(self,directory): + if self.remote: + args=["ssh","-T", self.hostname] + command=f"mkdir -p {directory}\n" + proc=Popen(args,stdout=PIPE, stdin=PIPE, executable="ssh") + out,err=proc.communicate(command.encode()) + else: + os.makedirs(directory) + class OutputHandler(threading.Thread): @@ -2078,6 +2207,7 @@ def run(self): self.stream.write(data) + class DistributedChannel(AbstractMessageChannel): default_distributed_instance = None @@ -2161,8 +2291,6 @@ def __init__(self, name_of_the_worker, legacy_interface_type=None, interpreter_e logger.debug("worker dir is %s", self.worker_dir) self._is_inuse = False - - def check_if_worker_is_up_to_date(self, object): # if self.hostname != 'localhost': diff --git a/src/amuse/rfi/core.py b/src/amuse/rfi/core.py index 04b8645b11..fcd81d48e6 100644 --- a/src/amuse/rfi/core.py +++ b/src/amuse/rfi/core.py @@ -960,7 +960,9 @@ def channel_factory(self): if self.channel_type == 'mpi': if MpiChannel.is_supported(): return MpiChannel - else: + else: + warnings.warn("MPI (unexpectedly?) not available, falling back to sockets channel") + self.channel_type="sockets" return SocketChannel elif self.channel_type == 'remote': @@ -1064,15 +1066,17 @@ class CodeWithDataDirectories(object): def __init__(self): - if not self.channel_type == 'distributed': + if self.channel_type == 'distributed': + warnings.warn("Code with DistributedChannel wants to make output directory, check..") + else: self.ensure_data_directory_exists(self.get_output_directory()) def ensure_data_directory_exists(self, directory): directory = os.path.expanduser(directory) directory = os.path.expandvars(directory) - + try: - os.makedirs(directory) + self.channel.makedirs(directory) except OSError as ex: if ex.errno == errno.EEXIST and os.path.isdir(directory): pass diff --git a/src/amuse/rfi/gencode.py b/src/amuse/rfi/gencode.py index 0e222f9857..ee5d930bd0 100755 --- a/src/amuse/rfi/gencode.py +++ b/src/amuse/rfi/gencode.py @@ -25,7 +25,7 @@ from amuse.rfi.tools import create_dir from amuse.rfi.tools import create_python_worker -from amuse.support import get_amuse_root_dir +from amuse.support import get_amuse_root_dir, get_amuse_package_dir from amuse.support.literature import TrackLiteratureReferences def get_amuse_directory(): @@ -46,14 +46,6 @@ def get_amuse_directory(): #~ else: #~ return os.path.abspath(directory_of_this_script) -def get_amuse_directory_root(): - filename_of_this_script = __file__ - directory_of_this_script = os.path.dirname(os.path.dirname(filename_of_this_script)) - if os.path.isabs(directory_of_this_script): - return directory_of_this_script - else: - return os.path.abspath(directory_of_this_script) - def setup_sys_path(): amuse_directory = os.environ["AMUSE_DIR"] sys.path.insert(0, amuse_directory) @@ -171,6 +163,12 @@ def __init__(self): default=False, dest="get_amuse_dir", help="Only output amuse directory") + self.parser.add_option( + "--get-amuse-package-dir", + action="store_true", + default=False, + dest="get_amuse_package_dir", + help="Only output the amuse package root directory") self.parser.add_option( "--get-amuse-configmk", action="store_true", @@ -201,7 +199,7 @@ def parse_options(self): def parse_arguments(self): - if self.options.get_amuse_dir or self.options.get_amuse_configmk: + if self.options.get_amuse_dir or self.options.get_amuse_package_dir or self.options.get_amuse_configmk: return if self.options.mode == 'dir': if len(self.arguments) != 1: @@ -385,6 +383,9 @@ def amusifier(): if uc.options.get_amuse_dir: print(get_amuse_root_dir()) exit(0) + elif uc.options.get_amuse_package_dir: + print(get_amuse_package_dir()) + exit(0) elif uc.options.get_amuse_configmk: with open(os.path.join(get_amuse_root_dir(), "config.mk")) as f: print(f.read()) diff --git a/src/amuse/support/__init__.py b/src/amuse/support/__init__.py index 913bfde447..11fa96d2a1 100644 --- a/src/amuse/support/__init__.py +++ b/src/amuse/support/__init__.py @@ -26,3 +26,11 @@ def get_amuse_root_dir(): def get_amuse_data_dir(): return _Defaults().amuse_root_dir + +def get_amuse_package_dir(): + filename_of_this_script = __file__ + directory_of_this_script = os.path.dirname(os.path.dirname(filename_of_this_script)) + if os.path.isabs(directory_of_this_script): + return directory_of_this_script + else: + return os.path.abspath(directory_of_this_script)