diff --git a/airflow/dags/DAG_PHT_run_train.py b/airflow/dags/DAG_PHT_run_train.py index c2cd9f9..788d05e 100644 --- a/airflow/dags/DAG_PHT_run_train.py +++ b/airflow/dags/DAG_PHT_run_train.py @@ -283,28 +283,37 @@ def execute_container(train_state): print(container_output) raise ValueError(f"The train execution returned a non zero exit code: {exit_code}") # - # def _copy(from_cont, from_path, to_cont, to_path): - # """ - # Copies a file from one container to another container - # :param from_cont: - # :param from_path: - # :param to_cont: - # :param to_path: - # :return: - # """ - # tar_stream, _ = from_cont.get_archive(from_path) - # to_cont.put_archive(os.path.dirname(to_path), tar_stream) - # - # base_image = ':'.join([train_state["repository"], 'base']) - # to_container = client.containers.create(base_image) - # # Copy results to base image - # _copy(from_cont=container, - # from_path="/opt/pht_results", - # to_cont=to_container, - # to_path="/opt/pht_results") - - container.commit(repository=train_state["repository"], tag=train_state["tag"]) + def _copy(from_cont, from_path, to_cont, to_path): + """ + Copies a file from one container to another container + :param from_cont: + :param from_path: + :param to_cont: + :param to_path: + :return: + """ + tar_stream, _ = from_cont.get_archive(from_path) + to_cont.put_archive(os.path.dirname(to_path), tar_stream) + + base_image = ':'.join([train_state["repository"], 'base']) + to_container = client.containers.create(base_image) + # Copy results to base image + _copy(from_cont=container, + from_path="/opt/pht_train", + to_cont=to_container, + to_path="/opt/pht_train") + _copy(from_cont=container, + from_path="/opt/pht_results", + to_cont=to_container, + to_path="/opt/pht_results") + _copy(from_cont=container, + from_path="/opt/train_config.json", + to_cont=to_container, + to_path="/opt/train_config.json") + + to_container.commit(repository=train_state["repository"], tag=train_state["tag"]) container.remove(v=True, force=True) + to_container.remove(v=True, force=True) if exit_code != 0: raise ValueError(f"The train execution returned a non zero exit code: {exit_code}")