From 6e93d4662df07b9d815c714a1ef7b2ddc6480415 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 4 Jul 2023 14:33:05 +0200 Subject: [PATCH 01/10] added apax ensemble model --- ipsuite/models/__init__.pyi | 4 +-- ipsuite/models/apax.py | 54 +++++++++++++++++++++++++++++++++++++ ipsuite/nodes.py | 1 + 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/ipsuite/models/__init__.pyi b/ipsuite/models/__init__.pyi index 5ee4300e..4d42ffe6 100644 --- a/ipsuite/models/__init__.pyi +++ b/ipsuite/models/__init__.pyi @@ -1,8 +1,8 @@ -from .apax import Apax +from .apax import Apax, ApaxEnsemble from .base import MLModel from .ensemble import EnsembleModel from .gap import GAP from .mace_model import MACE from .nequip import Nequip -__all__ = ["MLModel", "EnsembleModel", "GAP", "Nequip", "MACE", "Apax"] +__all__ = ["MLModel", "EnsembleModel", "GAP", "Nequip", "MACE", "Apax", "ApaxEnsemble"] diff --git a/ipsuite/models/apax.py b/ipsuite/models/apax.py index bdba7fab..2a57984f 100644 --- a/ipsuite/models/apax.py +++ b/ipsuite/models/apax.py @@ -1,17 +1,22 @@ import logging import pathlib import shutil +import typing +from uuid import uuid4 import ase.io import pandas as pd +from tqdm import tqdm import yaml import zntrack.utils from jax.config import config from zntrack import dvc, zn +from ipsuite import base from ipsuite import utils from ipsuite.models.base import MLModel from ipsuite.static_data import STATIC_PATH +from ipsuite.utils.ase_sim import freeze_copy_atoms from ipsuite.utils.helpers import check_duplicate_keys log = logging.getLogger(__name__) @@ -114,3 +119,52 @@ def get_calculator(self, **kwargs): self._handle_parameter_file() return ASECalculator(model_dir=self.model_directory) + + + +class ApaxEnsemble(base.IPSNode): + models: typing.List[Apax] = zntrack.zn.deps() + + uuid = zntrack.zn.outs() # to connect this Node to other Nodes it requires an output. + + def run(self) -> None: + self.uuid = str(uuid4()) + + def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator: + """Property to return a model specific ase calculator object. + + Returns + ------- + calc: + ase calculator object + """ + from apax.md import ASECalculator + param_files = [m._parameter["data"]["model_path"] for m in self.models] + + calc = ASECalculator(param_files[0]) + return calc + + + + def predict(self, atoms_list: typing.List[ase.Atoms]) -> typing.List[ase.Atoms]: + """Predict energy, forces and stresses. + + based on what was used to train for given atoms objects. + + Parameters + ---------- + atoms_list: typing.List[ase.Atoms] + list of atoms objects to predict on + + Returns + ------- + typing.List[ase.Atoms] + Atoms with updated calculators + """ + calc = self.get_calculator() + result = [] + for atoms in tqdm(atoms_list, ncols=120): + atoms.calc = calc + atoms.get_potential_energy() + result.append(freeze_copy_atoms(atoms)) + return result \ No newline at end of file diff --git a/ipsuite/nodes.py b/ipsuite/nodes.py index 064ea6c2..996a591c 100644 --- a/ipsuite/nodes.py +++ b/ipsuite/nodes.py @@ -13,6 +13,7 @@ class _Nodes: MACE = "ipsuite.models.MACE" Nequip = "ipsuite.models.Nequip" Apax = "ipsuite.models.Apax" + ApaxEnsemble = "ipsuite.models.ApaxEnsemble" # Configuration Selection IndexSelection = "ipsuite.configuration_selection.IndexSelection" From 4908bffc2c441ed61d5a42d88f3b5678e680ef0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 10 Jul 2023 10:50:47 +0200 Subject: [PATCH 02/10] added test for vmapped apax emsemble --- ipsuite/models/apax.py | 2 +- tests/integration/apax_minimal.yaml | 3 +- tests/integration/apax_minimal2.yaml | 22 +++++++++++ tests/integration/test_i_apax.py | 56 ++++++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 tests/integration/apax_minimal2.yaml diff --git a/ipsuite/models/apax.py b/ipsuite/models/apax.py index 2a57984f..cd8e3cb9 100644 --- a/ipsuite/models/apax.py +++ b/ipsuite/models/apax.py @@ -141,7 +141,7 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator: from apax.md import ASECalculator param_files = [m._parameter["data"]["model_path"] for m in self.models] - calc = ASECalculator(param_files[0]) + calc = ASECalculator(param_files) return calc diff --git a/tests/integration/apax_minimal.yaml b/tests/integration/apax_minimal.yaml index 0313413c..f5947400 100644 --- a/tests/integration/apax_minimal.yaml +++ b/tests/integration/apax_minimal.yaml @@ -6,9 +6,10 @@ data: valid_batch_size: 2 model: - nn: [512,512] + nn: [256,256] n_basis: 7 n_radial: 5 + calc_stress: true metrics: - name: energy diff --git a/tests/integration/apax_minimal2.yaml b/tests/integration/apax_minimal2.yaml new file mode 100644 index 00000000..65c029d3 --- /dev/null +++ b/tests/integration/apax_minimal2.yaml @@ -0,0 +1,22 @@ +n_epochs: 5 +seed: 667 + +data: + batch_size: 1 + valid_batch_size: 2 + +model: + nn: [256,256] + n_basis: 7 + n_radial: 5 + calc_stress: true + +metrics: + - name: energy + reductions: [mae] + - name: forces + reductions: [mae] + +loss: + - name: energy + - name: forces diff --git a/tests/integration/test_i_apax.py b/tests/integration/test_i_apax.py index 41cc4472..cf63795a 100644 --- a/tests/integration/test_i_apax.py +++ b/tests/integration/test_i_apax.py @@ -58,3 +58,59 @@ def test_model_training(proj_path, traj_file): assert isinstance(val, float) for val in analysis.forces.values(): assert isinstance(val, float) + + +def test_apax_ensemble(proj_path, traj_file): + shutil.copy(TEST_PATH / "apax_minimal.yaml", proj_path / "apax_minimal.yaml") + shutil.copy(TEST_PATH / "apax_minimal2.yaml", proj_path / "apax_minimal2.yaml") + + thermostat = ips.calculators.LangevinThermostat( + time_step=1.0, temperature=100.0, friction=0.01 + ) + + with ips.Project(automatic_node_names=True) as project: + raw_data = ips.AddData(file=traj_file, name="raw_data") + train_selection = UniformEnergeticSelection( + data=raw_data.atoms, n_configurations=10, name="data" + ) + + val_selection = UniformEnergeticSelection( + data=train_selection.excluded_atoms, n_configurations=8, name="val_data" + ) + + model1 = Apax( + config="apax_minimal.yaml", + data=train_selection.atoms, + validation_data=val_selection.atoms, + ) + + model2 = Apax( + config="apax_minimal2.yaml", + data=train_selection.atoms, + validation_data=val_selection.atoms, + ) + + ensemble_model = ips.models.ApaxEnsemble([model1, model2]) + + md = ips.calculators.ASEMD( + data=raw_data.atoms, + model=ensemble_model, + thermostat=thermostat, + steps=20, + sampling_rate=1, + ) + + uncertainty_selection = ips.configuration_selection.ThresholdSelection( + data=md, n_configurations=1, threshold=0.0001 + ) + + prediction = ips.analysis.Prediction(data=raw_data, model=ensemble_model) + prediction_metrics = ips.analysis.PredictionMetrics(data=prediction) + + project.run(eager=True) + + # uncertainty_selection.load() + # md.load() + + uncertainties = [x.calc.results["energy_uncertainty"] for x in md.atoms] + assert [md.atoms[np.argmax(uncertainties)]] == uncertainty_selection.atoms From 3b921be0d66ab361650f988b0b433bb30cc60314 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Jul 2023 09:08:21 +0000 Subject: [PATCH 03/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ipsuite/models/apax.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/ipsuite/models/apax.py b/ipsuite/models/apax.py index cd8e3cb9..cb61b8b7 100644 --- a/ipsuite/models/apax.py +++ b/ipsuite/models/apax.py @@ -6,14 +6,13 @@ import ase.io import pandas as pd -from tqdm import tqdm import yaml import zntrack.utils from jax.config import config +from tqdm import tqdm from zntrack import dvc, zn -from ipsuite import base -from ipsuite import utils +from ipsuite import base, utils from ipsuite.models.base import MLModel from ipsuite.static_data import STATIC_PATH from ipsuite.utils.ase_sim import freeze_copy_atoms @@ -121,7 +120,6 @@ def get_calculator(self, **kwargs): return ASECalculator(model_dir=self.model_directory) - class ApaxEnsemble(base.IPSNode): models: typing.List[Apax] = zntrack.zn.deps() @@ -139,13 +137,12 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator: ase calculator object """ from apax.md import ASECalculator + param_files = [m._parameter["data"]["model_path"] for m in self.models] calc = ASECalculator(param_files) return calc - - def predict(self, atoms_list: typing.List[ase.Atoms]) -> typing.List[ase.Atoms]: """Predict energy, forces and stresses. @@ -167,4 +164,4 @@ def predict(self, atoms_list: typing.List[ase.Atoms]) -> typing.List[ase.Atoms]: atoms.calc = calc atoms.get_potential_energy() result.append(freeze_copy_atoms(atoms)) - return result \ No newline at end of file + return result From eb6fb7940ada4f9f1cdc6be623b2f5c7f3bcfb78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 1 Aug 2023 10:01:14 +0200 Subject: [PATCH 04/10] switched test to graph mode --- tests/integration/test_i_apax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_i_apax.py b/tests/integration/test_i_apax.py index cf63795a..00a2f16a 100644 --- a/tests/integration/test_i_apax.py +++ b/tests/integration/test_i_apax.py @@ -107,10 +107,10 @@ def test_apax_ensemble(proj_path, traj_file): prediction = ips.analysis.Prediction(data=raw_data, model=ensemble_model) prediction_metrics = ips.analysis.PredictionMetrics(data=prediction) - project.run(eager=True) + project.run() - # uncertainty_selection.load() - # md.load() + uncertainty_selection.load() + md.load() uncertainties = [x.calc.results["energy_uncertainty"] for x in md.atoms] assert [md.atoms[np.argmax(uncertainties)]] == uncertainty_selection.atoms From 809b890edbe90e2e5689c7a56a1049dcc338b8a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 1 Aug 2023 10:01:35 +0200 Subject: [PATCH 05/10] removed deprecated methods, moved up apax import --- ipsuite/models/apax.py | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/ipsuite/models/apax.py b/ipsuite/models/apax.py index cb61b8b7..8af6e1e7 100644 --- a/ipsuite/models/apax.py +++ b/ipsuite/models/apax.py @@ -4,6 +4,8 @@ import typing from uuid import uuid4 +from apax.md import ASECalculator +from apax.train.run import run as apax_run import ase.io import pandas as pd import yaml @@ -84,8 +86,6 @@ def _handle_parameter_file(self): def train_model(self): """Train the model using `apax.train.run`""" - from apax.train.run import run as apax_run - apax_run(self._parameter, log_file=self.train_log_file) def move_metrics(self): @@ -114,7 +114,6 @@ def run(self): def get_calculator(self, **kwargs): """Get a apax ase calculator""" - from apax.md import ASECalculator self._handle_parameter_file() return ASECalculator(model_dir=self.model_directory) @@ -123,10 +122,8 @@ def get_calculator(self, **kwargs): class ApaxEnsemble(base.IPSNode): models: typing.List[Apax] = zntrack.zn.deps() - uuid = zntrack.zn.outs() # to connect this Node to other Nodes it requires an output. - def run(self) -> None: - self.uuid = str(uuid4()) + pass def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator: """Property to return a model specific ase calculator object. @@ -143,25 +140,3 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator: calc = ASECalculator(param_files) return calc - def predict(self, atoms_list: typing.List[ase.Atoms]) -> typing.List[ase.Atoms]: - """Predict energy, forces and stresses. - - based on what was used to train for given atoms objects. - - Parameters - ---------- - atoms_list: typing.List[ase.Atoms] - list of atoms objects to predict on - - Returns - ------- - typing.List[ase.Atoms] - Atoms with updated calculators - """ - calc = self.get_calculator() - result = [] - for atoms in tqdm(atoms_list, ncols=120): - atoms.calc = calc - atoms.get_potential_energy() - result.append(freeze_copy_atoms(atoms)) - return result From d8f7224e02be885500cc44c975110d9de357ebfe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Aug 2023 08:05:46 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ipsuite/models/apax.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ipsuite/models/apax.py b/ipsuite/models/apax.py index 8af6e1e7..12022ed9 100644 --- a/ipsuite/models/apax.py +++ b/ipsuite/models/apax.py @@ -4,12 +4,12 @@ import typing from uuid import uuid4 -from apax.md import ASECalculator -from apax.train.run import run as apax_run import ase.io import pandas as pd import yaml import zntrack.utils +from apax.md import ASECalculator +from apax.train.run import run as apax_run from jax.config import config from tqdm import tqdm from zntrack import dvc, zn @@ -139,4 +139,3 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator: calc = ASECalculator(param_files) return calc - From 155dd448aa4a3daf0cfef16e790d546c627aea47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 1 Aug 2023 10:15:35 +0200 Subject: [PATCH 07/10] linting --- ipsuite/models/apax.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/ipsuite/models/apax.py b/ipsuite/models/apax.py index 12022ed9..37a0475d 100644 --- a/ipsuite/models/apax.py +++ b/ipsuite/models/apax.py @@ -2,7 +2,6 @@ import pathlib import shutil import typing -from uuid import uuid4 import ase.io import pandas as pd @@ -11,13 +10,11 @@ from apax.md import ASECalculator from apax.train.run import run as apax_run from jax.config import config -from tqdm import tqdm from zntrack import dvc, zn from ipsuite import base, utils from ipsuite.models.base import MLModel from ipsuite.static_data import STATIC_PATH -from ipsuite.utils.ase_sim import freeze_copy_atoms from ipsuite.utils.helpers import check_duplicate_keys log = logging.getLogger(__name__) From fa89262ab16fafe72c08496fc7bb09f5f71a72b1 Mon Sep 17 00:00:00 2001 From: PythonFZ Date: Tue, 1 Aug 2023 11:16:52 +0200 Subject: [PATCH 08/10] fix issue with `apax._model` not being loaded --- ipsuite/models/apax.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ipsuite/models/apax.py b/ipsuite/models/apax.py index 37a0475d..0242558d 100644 --- a/ipsuite/models/apax.py +++ b/ipsuite/models/apax.py @@ -67,6 +67,10 @@ def _post_init_(self): self.validation_data = utils.helpers.get_deps_if_node( self.validation_data, "atoms" ) + self._handle_parameter_file() + + def _post_load_(self) -> None: + self._handle_parameter_file() def _handle_parameter_file(self): self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text()) @@ -100,7 +104,6 @@ def run(self): ase.io.write(self.train_data_file, self.data) ase.io.write(self.validation_data_file, self.validation_data) - self._handle_parameter_file() self.train_model() self.move_metrics() @@ -112,7 +115,6 @@ def run(self): def get_calculator(self, **kwargs): """Get a apax ase calculator""" - self._handle_parameter_file() return ASECalculator(model_dir=self.model_directory) From df37410baa1104c6bf251e5451d465591adabf76 Mon Sep 17 00:00:00 2001 From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com> Date: Tue, 1 Aug 2023 11:30:09 +0200 Subject: [PATCH 09/10] Update apax.py --- ipsuite/models/apax.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ipsuite/models/apax.py b/ipsuite/models/apax.py index 0242558d..b533e219 100644 --- a/ipsuite/models/apax.py +++ b/ipsuite/models/apax.py @@ -132,8 +132,7 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator: calc: ase calculator object """ - from apax.md import ASECalculator - + param_files = [m._parameter["data"]["model_path"] for m in self.models] calc = ASECalculator(param_files) From e30c2910abd312d37ed38dd22183bbf1922f7c61 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Aug 2023 09:30:36 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ipsuite/models/apax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipsuite/models/apax.py b/ipsuite/models/apax.py index b533e219..1d58840e 100644 --- a/ipsuite/models/apax.py +++ b/ipsuite/models/apax.py @@ -132,7 +132,7 @@ def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator: calc: ase calculator object """ - + param_files = [m._parameter["data"]["model_path"] for m in self.models] calc = ASECalculator(param_files)