Add coarse-graining example (#98)
* first draft

* track unwrapped and wrapped positions in MD code when cell is present, fix typo

* remove unused imports, update md length

* add link to data on Zenodo, change dataset filename
shinkle-lanl committed Sep 9, 2024
1 parent 0004728 commit 10a9055
Showing 6 changed files with 351 additions and 7 deletions.
7 changes: 7 additions & 0 deletions examples/coarse-graining/README.rst
@@ -0,0 +1,7 @@
The files in this directory allow one to train and run MD with a coarse-grained HIP-NN model. Details of the model can be found in the paper "Thermodynamic Transferability in Coarse-Grained Force Fields using Graph Neural Networks" by Shinkle et al., available at <https://doi.org/10.48550/arXiv.2406.12112>.

Before executing these files, one must download the training data from <https://doi.org/10.5281/zenodo.13717306>. The file should be placed at ``datasets/cg_methanol_trajectory.npz``, where the ``datasets/`` directory is at the same level as the hippynn repository.

1. Run ``cg_training.py`` to generate a model. The model will be saved in ``hippynn/examples/coarse-graining/model``.
2. Run ``cg_md.py`` to run MD using the model trained in step 1. The resulting trajectory will be saved in ``hippynn/examples/coarse-graining/md_results``.
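
Before training, it can help to confirm that the downloaded archive contains the arrays the example scripts read. The sketch below is not part of the example; the path is the one described above, and the array names are simply the keys accessed by ``cg_training.py`` and ``cg_md.py``.

import numpy as np

# Hypothetical sanity check; adjust the path to wherever the archive was placed.
path = "../../../datasets/cg_methanol_trajectory.npz"

with np.load(path) as data:
    # arrays used by cg_training.py / cg_md.py
    for key in ("positions", "velocities", "masses", "species",
                "cells", "forces", "rdf_bins", "rdf_values"):
        print(key, data[key].shape if key in data else "MISSING")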

115 changes: 115 additions & 0 deletions examples/coarse-graining/cg_md.py
@@ -0,0 +1,115 @@
import os

import numpy as np
import torch

from ase import units

from hippynn.experiment.serialization import load_checkpoint_from_cwd
from hippynn.graphs.predictor import Predictor
from hippynn.molecular_dynamics.md import (
Variable,
NullUpdater,
LangevinDynamics,
MolecularDynamics,
)
from hippynn.tools import active_directory

default_dtype = torch.float
torch.set_default_dtype(default_dtype)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load initial conditions
training_data_file = os.path.join(os.pardir, os.pardir, os.pardir, "datasets", "cg_methanol_trajectory.npz")

with np.load(training_data_file) as data:
    cell = torch.as_tensor(data["cells"][-1], dtype=default_dtype, device=device)[None, ...]
    masses = torch.as_tensor(data["masses"][-1], dtype=default_dtype, device=device)[None, ...]
    positions = torch.as_tensor(data["positions"][-1], dtype=default_dtype, device=device)[None, ...]
    velocities = torch.as_tensor(data["velocities"][-1], dtype=default_dtype, device=device)[None, ...]
    species = torch.as_tensor(data["species"][-1], dtype=torch.int, device=device)[None, ...]

positions_variable = Variable(
    name="positions",
    data={
        "position": positions,
        "velocity": velocities,
        "mass": masses,
        "acceleration": torch.zeros_like(velocities),
        "cell": cell,
    },
    model_input_map={"positions": "position"},
    device=device,
)

position_updater = LangevinDynamics(
    force_db_name="forces",
    temperature=700,
    frix=6,
    units_force=units.kcal / units.mol / units.Ang,
    units_acc=units.Ang / ((1000 * units.fs) ** 2),
    seed=1993,
)
positions_variable.updater = position_updater

cell_variable = Variable(
    name="cell",
    data={"cell": cell},
    model_input_map={"cells": "cell"},
    device=device,
    updater=NullUpdater(),
)

species_variable = Variable(
    name="species",
    data={"species": species},
    model_input_map={"species": "species"},
    device=device,
    updater=NullUpdater(),
)

# Load model
with active_directory("model"):
    check = load_checkpoint_from_cwd(model_device=device, restart_db=False)

repulse = check["training_modules"].model.node_from_name("repulse")
energy = check["training_modules"].model.node_from_name("sys_energy")

# Predictor wrapping the trained graph, with the repulsive and total system
# energies exposed as additional outputs
model = Predictor.from_graph(
    check["training_modules"].model,
    additional_outputs=[
        repulse.mol_energies,
        energy,
    ],
)

model.to(default_dtype)
model.to(device)

pairs = model.graph.node_from_name("pairs")
pairs.skin = 3 # see hippynn.graphs.nodes.pairs.KDTreePairsMemory documentation

# Run MD
with active_directory("md_results"):
    emdee = MolecularDynamics(
        variables=[positions_variable, species_variable, cell_variable],
        model=model,
    )

    emdee.run(dt=0.001, n_steps=20000)  # equilibration run; nothing recorded
    emdee.run(dt=0.001, n_steps=50000, record_every=50)  # production run; record data every 50 steps

    data = emdee.get_data()
    np.savez(
        "hippynn_cg_trajectory.npz",
        positions=data["positions_position"],
        velocities=data["positions_velocity"],
        masses=data["positions_mass"],
        accelerations=data["positions_acceleration"],
        cells=data["positions_cell"],
        unwrapped_positions=data["positions_unwrapped_position"],
        forces=data["positions_force"],
        species=data["species_species"],
    )
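
A quick way to inspect the trajectory written by `cg_md.py` is sketched below; it is not part of the commit, and simply loads the `np.savez` output produced above (the path assumes the script was run from this directory).

import numpy as np

traj = np.load("md_results/hippynn_cg_trajectory.npz")

print(traj["positions"].shape)            # wrapped positions of the recorded frames
print(traj["unwrapped_positions"].shape)  # unwrapped positions (useful for diffusion estimates)
print(traj["forces"].shape)

# e.g. mean force magnitude over the recorded frames
print(np.linalg.norm(traj["forces"], axis=-1).mean())
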
162 changes: 162 additions & 0 deletions examples/coarse-graining/cg_training.py
@@ -0,0 +1,162 @@
import os

import numpy as np
import torch

from hippynn.databases import NPZDatabase
from hippynn.experiment import SetupParams, setup_and_train
from hippynn.experiment.assembly import assemble_for_training
from hippynn.experiment.controllers import RaiseBatchSizeOnPlateau, PatienceController
from hippynn.graphs import IdxType
from hippynn.graphs.nodes import loss
from hippynn.graphs.nodes.base.algebra import AddNode
from hippynn.graphs.nodes.indexers import acquire_encoding_padding
from hippynn.graphs.nodes.inputs import SpeciesNode, PositionsNode, CellNode
from hippynn.graphs.nodes.networks import HipnnQuad
from hippynn.graphs.nodes.pairs import KDTreePairsMemory
from hippynn.graphs.nodes.physics import MultiGradientNode
from hippynn.graphs.nodes.targets import HEnergyNode
from hippynn.plotting import PlotMaker, Hist2D, SensitivityPlot
from hippynn.tools import active_directory

from repulsive_potential import RepulsivePotentialNode

training_data_file = os.path.join(os.pardir, os.pardir, os.pardir, "datasets", "cg_methanol_trajectory.npz")

# Estimate repulsive-potential parameters from the training data: taper where the
# RDF first becomes appreciable, with a strength set by the mean force magnitude.
with np.load(training_data_file) as data:
    idx = np.where(data["rdf_values"] > 0.01)[0][0]
    repulsive_potential_taper_point = data["rdf_bins"][idx]
    repulsive_potential_strength = np.abs(data["forces"]).mean()

## Initialize needed nodes for network
# Network input nodes
species = SpeciesNode(name="species", db_name="species")
positions = PositionsNode(name="positions", db_name="positions")
cells = CellNode(name="cells", db_name="cells")

# Network hyperparameters
network_params = {
    "possible_species": [0, 1],
    "n_features": 128,
    "n_sensitivities": 20,
    "dist_soft_min": 2.0,
    "dist_soft_max": 13.0,
    "dist_hard_max": 15.0,
    "n_interaction_layers": 1,
    "n_atom_layers": 3,
    "sensitivity_type": "inverse",
    "resnet": True,
}

# Species encoder
enc, pdx = acquire_encoding_padding([species], species_set=[0, 1])

# Pair finder
pair_finder = KDTreePairsMemory(
    "pairs",
    (positions, enc, pdx, cells),
    dist_hard_max=network_params["dist_hard_max"],
    skin=0,
)

# HIP-NN-TS node with l=2
network = HipnnQuad(
    "HIPNN", (pdx, pair_finder), module_kwargs=network_params, periodic=True
)

# Network energy prediction
henergy = HEnergyNode("HEnergy", parents=(network,))

# Repulsive potential
repulse = RepulsivePotentialNode(
    "repulse",
    (pair_finder, pdx),
    taper_point=repulsive_potential_taper_point,
    strength=repulsive_potential_strength,
    dr=0.15,
    perc=0.05,
)

# Combined energy prediction
energy = AddNode(henergy.main_output, repulse.mol_energies)
energy.name = "energies"
energy._index_state = IdxType.Molecules

sys_energy = energy.main_output
sys_energy.name = "sys_energy"

# Force node
grad = MultiGradientNode("forces", energy, (positions,), signs=-1)
force = grad.children[0]
force.db_name = "forces"

## Define losses
force_rsq = loss.Rsq.of_node(force)
force_rmse = loss.MSELoss.of_node(force) ** (1 / 2)
force_mae = loss.MAELoss.of_node(force)
total_loss = force_rmse + force_mae

validation_losses = {
    "ForceRMSE": force_rmse,
    "ForceMAE": force_mae,
    "ForceRsq": force_rsq,
    "TotalLoss": total_loss,
}

plotters = [
    Hist2D.compare(force, saved="forces", shown=False),
    SensitivityPlot(
        network.torch_module.sensitivity_layers[0], saved="sensitivity", shown=False
    ),
]

plot_maker = PlotMaker(
    *plotters,
    plot_every=10,
)

## Build network
training_modules, db_info = assemble_for_training(
    total_loss, validation_losses, plot_maker=plot_maker
)

## Load training data
database = NPZDatabase(
    training_data_file,
    seed=0,
    **db_info,
    valid_size=0.1,
    test_size=0.1,
)

## Set up optimizer
optimizer = torch.optim.Adam(training_modules.model.parameters(), lr=1e-3)

scheduler = RaiseBatchSizeOnPlateau(
    optimizer=optimizer,
    max_batch_size=64,
    patience=10,
    factor=0.5,
)

controller = PatienceController(
    optimizer=optimizer,
    scheduler=scheduler,
    batch_size=1,
    fraction_train_eval=0.2,
    eval_batch_size=1,
    max_epochs=200,
    termination_patience=20,
    stopping_key="TotalLoss",
)

experiment_params = SetupParams(controller=controller)

## Train!
with active_directory("model"):
    metric_tracker = setup_and_train(
        training_modules=training_modules,
        database=database,
        setup_params=experiment_params,
    )

58 changes: 58 additions & 0 deletions examples/coarse-graining/repulsive_potential.py
@@ -0,0 +1,58 @@
import math
import torch

from hippynn.graphs import IdxType
from hippynn.graphs.nodes.base import ExpandParents
from hippynn.graphs.nodes.base.definition_helpers import AutoKw
from hippynn.graphs.nodes.base.multi import MultiNode
from hippynn.graphs.nodes.tags import PairIndexer, AtomIndexer
from hippynn.layers import pairs

## Define repulsive potential node for hippynn
class RepulsivePotential(torch.nn.Module):
    def __init__(self, taper_point, strength, dr, perc):
        '''
        Let F(r) be the force between two particles at distance r generated
        by this potential. Then
            F(taper_point) = perc * strength
            F(taper_point - dr) = strength
        E.g. if taper_point=3, strength=1, dr=0.5, and perc=0.01, then
            F(3) = 0.01
            F(2.5) = 1
        '''
        super().__init__()
        self.t = taper_point
        self.s = strength
        self.d = dr
        self.p = perc

        self.a = (1 / self.d) * math.log(1 / self.p)
        self.g = -1 * self.s * self.p * math.exp(self.a * self.t) / self.a

        self.summer = pairs.MolPairSummer()

    def forward(self, pair_dist, pair_first, mol_index, n_molecules):
        atom_energies = -1 * self.g * torch.exp(-1 * self.a * pair_dist)
        mol_energies = self.summer(atom_energies, mol_index, n_molecules, pair_first)
        return mol_energies, atom_energies

class RepulsivePotentialNode(ExpandParents, AutoKw, MultiNode):
    _input_names = "pair_dist", "pair_first", "mol_index", "n_molecules"
    _output_names = "mol_energies", "atom_energies"
    _auto_module_class = RepulsivePotential
    _output_index_states = IdxType.Molecules, IdxType.Pair

    @_parent_expander.match(PairIndexer, AtomIndexer)
    def expansion(self, pairfinder, pidxer, **kwargs):
        return pairfinder.pair_dist, pairfinder.pair_first, pidxer.mol_index, pidxer.n_molecules

    def __init__(self, name, parents, taper_point, strength, dr, perc, module="auto"):
        self.module_kwargs = {
            "taper_point": taper_point,
            "strength": strength,
            "dr": dr,
            "perc": perc,
        }
        parents = self.expand_parents(parents, module="auto")
        super().__init__(name, parents, module=module)
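
To see that `a` and `g` above give the tapering described in the docstring: the pair energy is E(r) = -g * exp(-a * r), so the pair force is F(r) = -dE/dr = strength * perc * exp(a * (taper_point - r)), which equals perc * strength at the taper point and strength at taper_point - dr. A small numerical check (a hypothetical standalone script, using the docstring's example values, not part of the commit):

import math

t, s, d, p = 3.0, 1.0, 0.5, 0.01   # taper_point, strength, dr, perc
a = (1 / d) * math.log(1 / p)
g = -1 * s * p * math.exp(a * t) / a

def pair_energy(r):
    return -1 * g * math.exp(-1 * a * r)

def pair_force(r, eps=1e-6):
    # F = -dE/dr, via a central finite difference
    return -(pair_energy(r + eps) - pair_energy(r - eps)) / (2 * eps)

print(pair_force(t))      # ~0.01 = perc * strength
print(pair_force(t - d))  # ~1.0  = strength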
2 changes: 1 addition & 1 deletion hippynn/experiment/serialization.py
@@ -140,7 +140,7 @@ def load_checkpoint(
    :param structure_fname: name of the structure file
    :param state_fname: name of the state file
-    :param restart_db: restore database or not, defaults to True
+    :param restart_db: restore database or not, defaults to False
    :param map_location: device mapping argument for ``torch.load``, defaults to None
    :param model_device: automatically handle device mapping. Defaults to None, defaults to None
    :return: experiment structure
14 changes: 8 additions & 6 deletions hippynn/molecular_dynamics/md.py
@@ -1,5 +1,6 @@
from __future__ import annotations
from functools import singledispatchmethod
+from copy import copy

import numpy as np
import torch
@@ -9,7 +10,6 @@
from ..graphs import Predictor
from ..layers.pairs.periodic import wrap_systems_torch


class Variable:
"""
Tracks the state of a quantity (eg. position, cell, species,
@@ -307,11 +307,12 @@ def pre_step(self, dt):

        self.variable.data["position"] = self.variable.data["position"] + self.variable.data["velocity"] * dt

-        try:
-            _, self.variable.data["position"], *_ = wrap_systems_torch(coords=self.variable.data["position"], cell=self.variable.data["cell"], cutoff=0)  # cutoff only used for discarded outputs; can be set arbitrarily
-        except KeyError:
-            pass

+        if "cell" in self.variable.data.keys():
+            _, self.variable.data["position"], *_ = wrap_systems_torch(coords=self.variable.data["position"], cell=self.variable.data["cell"], cutoff=0)  # cutoff only impacts unused outputs; can be set arbitrarily
+            try:
+                self.variable.data["unwrapped_position"] = self.variable.data["unwrapped_position"] + self.variable.data["velocity"] * dt
+            except KeyError:
+                self.variable.data["unwrapped_position"] = copy(self.variable.data["position"])
    def post_step(self, dt, model_outputs):
        """Updates to variables performed during each step of MD simulation after HIPNN model evaluation
@@ -406,6 +407,7 @@ def model(self, model):
+ f" Entries in the 'model_input_map' should have the form 'hipnn-db_name: variable-data-key' where 'hipnn-db_name'"
+ f" refers to the db_name of an input for the hippynn Predictor model,"
+ f" and 'variable-data-key' corresponds to a key in the 'data' dictionary of one of the Variables."
+ f" Currently assigned db_names are: {variable_data_db_names}."
)
self._model = model
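
The `pre_step` change in this file wraps positions back into the simulation cell whenever the variable carries a `cell` entry, while a separate `unwrapped_position` entry keeps accumulating the raw displacements (seeded from the wrapped positions on the first step). A toy one-dimensional sketch of that bookkeeping, written with plain tensor arithmetic rather than `wrap_systems_torch`, purely for illustration:

import torch

box = torch.tensor([10.0])             # hypothetical box length
position = torch.tensor([9.5])
unwrapped_position = position.clone()  # first step: seeded from the wrapped position
velocity = torch.tensor([2.0])
dt = 0.5

for _ in range(3):
    position = position + velocity * dt
    unwrapped_position = unwrapped_position + velocity * dt
    position = position % box          # wrapped back into [0, box)

print(position)            # tensor([2.5000])  stays inside the box
print(unwrapped_position)  # tensor([12.5000]) keeps growing; this is what one would use for MSD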

Expand Down
