-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #206 from pyiron/executor
Implement a concurrent.futures.Executor for pysqa
- Loading branch information
Showing
9 changed files
with
397 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,5 @@ dependencies: | |
- jinja2 =3.1.2 | ||
- paramiko =3.2.0 | ||
- tqdm =4.65.0 | ||
- pympipool =0.5.4 | ||
- cloudpickle =2.2.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import sys

from pysqa.executor.backend import command_line


# Module entry point: `python -m pysqa.executor --cores <n> --path <dir>`.
# Forwards the raw CLI arguments (without the program name) to the backend's
# command_line parser, which starts the worker-side execution loop.
if __name__ == "__main__":
    command_line(arguments_lst=sys.argv[1:])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import os | ||
import sys | ||
|
||
from pympipool import PoolExecutor | ||
from pysqa.executor.helper import ( | ||
read_from_file, | ||
deserialize, | ||
write_to_file, | ||
serialize_result, | ||
) | ||
|
||
|
||
def execute_files_from_list(tasks_in_progress_dict, cache_directory, executor):
    """Run a single polling pass of the worker loop.

    Submits every new ``<key>.in.pl`` task file found in *cache_directory* to
    *executor* — skipping tasks that are already running (key present in
    *tasks_in_progress_dict*) or already finished (``<key>.out.pl`` exists) —
    and writes a ``<key>.out.pl`` result file for every submitted task whose
    future has completed.

    Fix over the original: completed tasks are now removed from
    *tasks_in_progress_dict* after their result is written, so each result
    file is written exactly once instead of being rewritten on every pass.

    Args:
        tasks_in_progress_dict (dict): mutable mapping of task key -> Future;
            updated in place.
        cache_directory (str): directory exchanged with the submitting side.
        executor: object with a ``submit(fn, *args, **kwargs)`` method.
    """
    file_lst = os.listdir(cache_directory)
    for file_name_in in file_lst:
        key = file_name_in.split(".in.pl")[0]
        file_name_out = key + ".out.pl"
        if (
            file_name_in.endswith(".in.pl")
            and file_name_out not in file_lst
            and key not in tasks_in_progress_dict.keys()
        ):
            funct_dict = read_from_file(
                file_name=os.path.join(cache_directory, file_name_in)
            )
            apply_dict = deserialize(funct_dict=funct_dict)
            for k, v in apply_dict.items():
                tasks_in_progress_dict[k] = executor.submit(
                    v["fn"], *v["args"], **v["kwargs"]
                )
    done_lst = []
    for k, v in tasks_in_progress_dict.items():
        if v.done():
            write_to_file(
                funct_dict=serialize_result(result_dict={k: v.result()}),
                state="out",
                cache_directory=cache_directory,
            )
            done_lst.append(k)
    # Drop finished tasks; the .out.pl file on disk prevents resubmission.
    for k in done_lst:
        del tasks_in_progress_dict[k]
|
||
|
||
def execute_tasks(cores, cache_directory):
    """Run the worker-side event loop on the allocated compute resource.

    Opens a ``pympipool.PoolExecutor`` with *cores* workers and repeatedly
    scans *cache_directory* for serialized task files, executing new ones and
    writing result files via :func:`execute_files_from_list`.

    NOTE(review): the ``while True`` loop never terminates and polls the
    directory without sleeping — presumably the enclosing batch job is simply
    cancelled through the queuing system (see Executor.shutdown); confirm the
    intended shutdown path and consider throttling the polling.

    Args:
        cores (int): number of workers for the pool.
        cache_directory (str): directory exchanged with the submitting side.
    """
    tasks_in_progress_dict = {}
    with PoolExecutor(
        max_workers=cores,
        oversubscribe=False,
        enable_flux_backend=False,
        enable_slurm_backend=False,
        cwd=cache_directory,
        sleep_interval=0.1,
        queue_adapter=None,
        queue_adapter_kwargs=None,
    ) as exe:
        # Poll forever: submit new .in.pl task files, write .out.pl results.
        while True:
            execute_files_from_list(
                tasks_in_progress_dict=tasks_in_progress_dict,
                cache_directory=cache_directory,
                executor=exe,
            )
|
||
|
||
def command_line(arguments_lst=None):
    """Parse ``--cores <n> --path <dir>`` and start the task execution loop.

    Args:
        arguments_lst (list[str] | None): CLI arguments; defaults to
            ``sys.argv[1:]`` when None.

    Raises:
        ValueError: if the ``--cores`` or ``--path`` flag is missing.
        IndexError: if a flag is present but its value is missing.

    Fix over the original: the core count is converted with ``int()`` — the
    raw ``sys.argv`` value is a string, and ``PoolExecutor``'s ``max_workers``
    requires an integer.
    """
    if arguments_lst is None:
        arguments_lst = sys.argv[1:]
    cores_arg = arguments_lst[arguments_lst.index("--cores") + 1]
    path_arg = arguments_lst[arguments_lst.index("--path") + 1]
    execute_tasks(cores=int(cores_arg), cache_directory=path_arg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import os | ||
import queue | ||
from threading import Thread | ||
from concurrent.futures import Future, Executor as FutureExecutor | ||
|
||
from pympipool import cancel_items_in_queue | ||
from pysqa.executor.helper import ( | ||
reload_previous_futures, | ||
find_executed_tasks, | ||
serialize_funct, | ||
write_to_file, | ||
) | ||
|
||
|
||
class Executor(FutureExecutor):
    """A concurrent.futures.Executor that runs tasks through an HPC queue.

    Tasks are serialized to ``<key>.in.pl`` files in a cache directory; a
    remote worker (submitted here as a queued job running
    ``python -m pysqa.executor``) executes them and writes ``<key>.out.pl``
    result files, which a local background thread watches in order to resolve
    the corresponding futures.
    """

    def __init__(self, cwd=None, queue_adapter=None, queue_adapter_kwargs=None):
        """Start the remote worker job and the local monitoring thread.

        Args:
            cwd: cache directory used to exchange task/result files.
                NOTE(review): the default None crashes in
                ``os.path.expanduser(None)`` — presumably a directory is
                always required; confirm or make the parameter mandatory.
            queue_adapter: pysqa queue adapter used to submit/delete the
                worker job.
            queue_adapter_kwargs (dict): forwarded to
                ``queue_adapter.submit_job()``; must contain ``"cores"``,
                which is also passed to the worker via ``--cores``.
        """
        self._task_queue = queue.Queue()
        self._memory_dict = {}
        self._cache_directory = os.path.abspath(os.path.expanduser(cwd))
        self._queue_adapter = queue_adapter
        # Re-create futures for task files left over from a previous session.
        reload_previous_futures(
            future_queue=self._task_queue,
            future_dict=self._memory_dict,
            cache_directory=self._cache_directory,
        )
        # Command executed by the queued job on the compute resource.
        command = (
            "python -m pysqa.executor --cores "
            + str(queue_adapter_kwargs["cores"])
            + " --path "
            + str(self._cache_directory)
        )
        self._queue_id = self._queue_adapter.submit_job(
            working_directory=self._cache_directory,
            command=command,
            **queue_adapter_kwargs
        )
        # Background thread resolving futures when .out.pl files appear.
        self._process = Thread(
            target=find_executed_tasks,
            kwargs={
                "future_queue": self._task_queue,
                "cache_directory": self._cache_directory,
            },
        )
        self._process.start()

    def submit(self, fn, *args, **kwargs):
        """Schedule ``fn(*args, **kwargs)`` for remote execution.

        The task key is the function name plus a content hash of the pickled
        call, so resubmitting an identical call returns the existing future
        instead of writing a second task file.

        Returns:
            concurrent.futures.Future: resolved once the result file appears.
        """
        funct_dict = serialize_funct(fn, *args, **kwargs)
        key = list(funct_dict.keys())[0]
        if key not in self._memory_dict.keys():
            self._memory_dict[key] = Future()
            _ = write_to_file(
                funct_dict=funct_dict, state="in", cache_directory=self._cache_directory
            )[0]
            self._task_queue.put({key: self._memory_dict[key]})
        return self._memory_dict[key]

    def shutdown(self, wait=True, *, cancel_futures=False):
        """Stop the monitoring thread and delete the remote worker job.

        Mirrors ``concurrent.futures.Executor.shutdown()``; the sentinel dict
        tells :func:`find_executed_tasks` to exit its loop before the thread
        is joined.
        """
        if cancel_futures:
            cancel_items_in_queue(que=self._task_queue)
        self._task_queue.put({"shutdown": True, "wait": wait})
        self._queue_adapter.delete_job(process_id=self._queue_id)
        self._process.join()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import os | ||
import re | ||
import queue | ||
from concurrent.futures import Future | ||
|
||
import hashlib | ||
import cloudpickle | ||
|
||
|
||
def deserialize(funct_dict):
    """Unpickle every binary payload in *funct_dict*.

    Returns a dict mapping the original keys to the restored objects, or an
    empty dict when a payload is truncated (EOFError from cloudpickle).
    """
    restored = {}
    try:
        for name, payload in funct_dict.items():
            restored[name] = cloudpickle.loads(payload)
    except EOFError:
        return {}
    return restored
|
||
|
||
def find_executed_tasks(future_queue, cache_directory):
    """Monitor loop: resolve futures as their result files appear.

    Consumes ``{key: Future}`` dicts from *future_queue*, remembers them in a
    local dict, and on every iteration checks *cache_directory* for matching
    ``<key>.out.pl`` files to set the futures' results. A sentinel dict with
    ``{"shutdown": True, ...}`` terminates the loop.

    Fix over the original: the queue is read with a short blocking
    ``get(timeout=...)`` instead of ``get_nowait()``, so the loop no longer
    busy-spins a full CPU core while idle.

    Args:
        future_queue (queue.Queue): channel for new tasks / shutdown sentinel.
        cache_directory (str): directory where result files appear.
    """
    task_memory_dict = {}
    while True:
        task_dict = {}
        try:
            # Block briefly rather than poll hot; 10 ms keeps result latency
            # low while avoiding a 100% CPU busy-wait.
            task_dict = future_queue.get(timeout=0.01)
        except queue.Empty:
            pass
        if "shutdown" in task_dict.keys() and task_dict["shutdown"]:
            break
        else:
            _update_task_dict(
                task_dict=task_dict,
                task_memory_dict=task_memory_dict,
                cache_directory=cache_directory,
            )
|
||
|
||
def read_from_file(file_name):
    """Read *file_name* as binary and return ``{task_key: payload}``.

    The task key is the file's base name up to the first dot, so both
    ``<key>.in.pl`` and ``<key>.out.pl`` map to ``<key>``.

    Fix over the original: uses ``os.path.basename`` instead of splitting on
    ``"/"``, so Windows-style paths are handled correctly as well.

    Args:
        file_name (str): path to the serialized task/result file.

    Returns:
        dict: single-entry mapping of task key to the raw file contents.
    """
    name = os.path.basename(file_name).split(".")[0]
    with open(file_name, "rb") as f:
        return {name: f.read()}
|
||
|
||
def reload_previous_futures(future_queue, future_dict, cache_directory):
    """Rebuild futures for task files left over in *cache_directory*.

    Every ``<key>.in.pl`` file gets a fresh Future stored in *future_dict*.
    If the matching ``<key>.out.pl`` already exists the future is resolved
    from it immediately; otherwise the task is put on *future_queue* so the
    monitor loop picks it up.
    """
    cached_files = os.listdir(cache_directory)
    input_files = [name for name in cached_files if name.endswith(".in.pl")]
    for input_file in input_files:
        key = input_file.split(".in.pl")[0]
        future = Future()
        future_dict[key] = future
        output_file = key + ".out.pl"
        if output_file in cached_files:
            _set_future(
                file_name=os.path.join(cache_directory, output_file),
                future=future,
            )
        else:
            future_queue.put({key: future})
|
||
|
||
def serialize_result(result_dict):
    """Pickle every value in *result_dict*, keeping the keys unchanged."""
    serialized = {}
    for key, value in result_dict.items():
        serialized[key] = cloudpickle.dumps(value)
    return serialized
|
||
|
||
def serialize_funct(fn, *args, **kwargs):
    """Pickle a callable together with its arguments.

    Returns a one-entry dict ``{unique_name: payload}`` where *unique_name*
    is the function name followed by a content hash of the pickled call, so
    identical calls map to the same key.
    """
    payload = cloudpickle.dumps({"fn": fn, "args": args, "kwargs": kwargs})
    key = fn.__name__ + _get_hash(binary=payload)
    return {key: payload}
|
||
|
||
def write_to_file(funct_dict, state, cache_directory):
    """Dump each binary payload in *funct_dict* to ``<key>.<state>.pl``.

    Files are written inside *cache_directory*; the list of file names
    (without the directory prefix) is returned in insertion order.
    """
    written = []
    for key, payload in funct_dict.items():
        file_name = key + "." + state + ".pl"
        with open(os.path.join(cache_directory, file_name), "wb") as handle:
            handle.write(payload)
        written.append(file_name)
    return written
|
||
|
||
def _get_file_name(name, state): | ||
return name + "." + state + ".pl" | ||
|
||
|
||
def _get_hash(binary): | ||
# Remove specification of jupyter kernel from hash to be deterministic | ||
binary_no_ipykernel = re.sub(b"(?<=/ipykernel_)(.*)(?=/)", b"", binary) | ||
return str(hashlib.md5(binary_no_ipykernel).hexdigest()) | ||
|
||
|
||
def _set_future(file_name, future):
    """Resolve *future* with the single result stored in *file_name*.

    The future is left untouched when deserialization yields anything other
    than exactly one value (e.g. a truncated result file).
    """
    result_values = list(
        deserialize(funct_dict=read_from_file(file_name=file_name)).values()
    )
    if len(result_values) == 1:
        future.set_result(result_values[0])
|
||
|
||
def _update_task_dict(task_dict, task_memory_dict, cache_directory): | ||
file_lst = os.listdir(cache_directory) | ||
for key, future in task_dict.items(): | ||
task_memory_dict[key] = future | ||
for key, future in task_memory_dict.items(): | ||
file_name_out = _get_file_name(name=key, state="out") | ||
if not future.done() and file_name_out in file_lst: | ||
_set_future( | ||
file_name=os.path.join(cache_directory, file_name_out), | ||
future=future, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.