Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apax ensemble #149

Merged
merged 18 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ipsuite/models/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .apax import Apax
from .apax import Apax, ApaxEnsemble
from .base import MLModel
from .ensemble import EnsembleModel
from .gap import GAP
from .mace_model import MACE
from .nequip import Nequip

__all__ = ["MLModel", "EnsembleModel", "GAP", "Nequip", "MACE", "Apax"]
__all__ = ["MLModel", "EnsembleModel", "GAP", "Nequip", "MACE", "Apax", "ApaxEnsemble"]
36 changes: 30 additions & 6 deletions ipsuite/models/apax.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import logging
import pathlib
import shutil
import typing

import ase.io
import pandas as pd
import yaml
import zntrack.utils
from apax.md import ASECalculator
from apax.train.run import run as apax_run
from jax.config import config
from zntrack import dvc, zn

from ipsuite import utils
from ipsuite import base, utils
from ipsuite.models.base import MLModel
from ipsuite.static_data import STATIC_PATH
from ipsuite.utils.helpers import check_duplicate_keys
Expand Down Expand Up @@ -64,6 +67,10 @@ def _post_init_(self):
self.validation_data = utils.helpers.get_deps_if_node(
self.validation_data, "atoms"
)
self._handle_parameter_file()

def _post_load_(self) -> None:
    """Parse the parameter file again for a loaded node.

    Delegates to ``_handle_parameter_file``, which reads ``self.config`` and
    stores the parsed YAML in ``self._parameter``.
    NOTE(review): presumably zntrack invokes this hook after loading a node
    from disk -- confirm against the zntrack version in use.
    """
    self._handle_parameter_file()

def _handle_parameter_file(self):
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())
Expand All @@ -80,8 +87,6 @@ def _handle_parameter_file(self):

def train_model(self):
    """Train the model using `apax.train.run`.

    Hands the parsed parameter dictionary (``self._parameter``) to
    ``apax_run`` and passes ``self.train_log_file`` as its ``log_file``.
    """
    # `apax_run` is provided by the module-level
    # `from apax.train.run import run as apax_run`; no local import needed.
    apax_run(self._parameter, log_file=self.train_log_file)

def move_metrics(self):
Expand All @@ -99,7 +104,6 @@ def run(self):

ase.io.write(self.train_data_file, self.data)
ase.io.write(self.validation_data_file, self.validation_data)
self._handle_parameter_file()

self.train_model()
self.move_metrics()
Expand All @@ -110,7 +114,27 @@ def run(self):

def get_calculator(self, **kwargs):
    """Get an apax ase calculator.

    Parameters
    ----------
    **kwargs
        Not used by this implementation.

    Returns
    -------
    ASECalculator
        Calculator created with ``model_dir=self.model_directory``.
    """
    # Re-parse the parameter file before constructing the calculator.
    self._handle_parameter_file()
    # `ASECalculator` comes from the module-level `from apax.md import ...`.
    return ASECalculator(model_dir=self.model_directory)


class ApaxEnsemble(base.IPSNode):
    """Node combining several ``Apax`` model nodes into a single calculator.

    The node performs no computation of its own; it only collects the model
    paths of its dependencies and hands them to ``ASECalculator``.
    """

    # Apax model nodes this ensemble is built from, declared as zntrack
    # dependencies.
    models: typing.List[Apax] = zntrack.zn.deps()

    def run(self) -> None:
        """Do nothing; this node has no work of its own to execute."""

    def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
        """Return an ase calculator built from all models of the ensemble.

        Parameters
        ----------
        **kwargs
            Not used by this implementation.

        Returns
        -------
        calc:
            ``ASECalculator`` constructed from the list of model paths.
        """
        # Each model stores its path under data/model_path of its parsed
        # parameter file. `ASECalculator` is imported at module level.
        model_paths = [m._parameter["data"]["model_path"] for m in self.models]
        return ASECalculator(model_paths)
1 change: 1 addition & 0 deletions ipsuite/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class _Nodes:
MACE = "ipsuite.models.MACE"
Nequip = "ipsuite.models.Nequip"
Apax = "ipsuite.models.Apax"
ApaxEnsemble = "ipsuite.models.ApaxEnsemble"

# Configuration Selection
IndexSelection = "ipsuite.configuration_selection.IndexSelection"
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/apax_minimal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ data:
valid_batch_size: 2

model:
nn: [512,512]
nn: [256,256]
n_basis: 7
n_radial: 5
calc_stress: true

metrics:
- name: energy
Expand Down
22 changes: 22 additions & 0 deletions tests/integration/apax_minimal2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
n_epochs: 5
seed: 667

data:
batch_size: 1
valid_batch_size: 2

model:
nn: [256,256]
n_basis: 7
n_radial: 5
calc_stress: true

metrics:
- name: energy
reductions: [mae]
- name: forces
reductions: [mae]

loss:
- name: energy
- name: forces
56 changes: 56 additions & 0 deletions tests/integration/test_i_apax.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,59 @@ def test_model_training(proj_path, traj_file):
assert isinstance(val, float)
for val in analysis.forces.values():
assert isinstance(val, float)


def test_apax_ensemble(proj_path, traj_file):
    """Build an ensemble of two Apax models, run MD and select by uncertainty.

    After the project has run, the single configuration picked by
    ``ThresholdSelection`` must be the MD frame with the largest
    ``energy_uncertainty``.
    """
    # Both model configs have to be present inside the project directory.
    for config_name in ("apax_minimal.yaml", "apax_minimal2.yaml"):
        shutil.copy(TEST_PATH / config_name, proj_path / config_name)

    thermostat = ips.calculators.LangevinThermostat(
        time_step=1.0, temperature=100.0, friction=0.01
    )

    with ips.Project(automatic_node_names=True) as project:
        raw_data = ips.AddData(file=traj_file, name="raw_data")

        # Training set first, validation set from what the first selection
        # left over.
        train_selection = UniformEnergeticSelection(
            data=raw_data.atoms, n_configurations=10, name="data"
        )
        val_selection = UniformEnergeticSelection(
            data=train_selection.excluded_atoms, n_configurations=8, name="val_data"
        )

        # Two models on identical data, differing only in their config file.
        model1 = Apax(
            config="apax_minimal.yaml",
            data=train_selection.atoms,
            validation_data=val_selection.atoms,
        )
        model2 = Apax(
            config="apax_minimal2.yaml",
            data=train_selection.atoms,
            validation_data=val_selection.atoms,
        )

        ensemble_model = ips.models.ApaxEnsemble([model1, model2])

        md = ips.calculators.ASEMD(
            data=raw_data.atoms,
            model=ensemble_model,
            thermostat=thermostat,
            steps=20,
            sampling_rate=1,
        )

        uncertainty_selection = ips.configuration_selection.ThresholdSelection(
            data=md, n_configurations=1, threshold=0.0001
        )

        prediction = ips.analysis.Prediction(data=raw_data, model=ensemble_model)
        prediction_metrics = ips.analysis.PredictionMetrics(data=prediction)

    project.run()

    uncertainty_selection.load()
    md.load()

    # The selected configuration must be the most uncertain MD frame.
    frames = md.atoms
    energy_uncertainties = [
        frame.calc.results["energy_uncertainty"] for frame in frames
    ]
    most_uncertain = frames[np.argmax(energy_uncertainties)]
    assert [most_uncertain] == uncertainty_selection.atoms
Loading