Skip to content

Commit

Permalink
update documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkle-lanl committed Jun 5, 2024
1 parent dfc2d50 commit 4638b00
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 26 deletions.
2 changes: 1 addition & 1 deletion examples/molecular_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
"coordinates": "position",
},
device=device,
updater=VelocityVerlet(force_key="force"),
updater=VelocityVerlet(force_db_name="force"),
)

# Define species and cell Variables
Expand Down
56 changes: 31 additions & 25 deletions hippynn/molecular_dynamics/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Variable:
"""
Tracks the state of a quantity (eg. position, cell, species,
volume) on each particle or each system in an MD simulation. Can
also hold additional data associated to that quantity (such as
also hold additional data associated to that quantity (such as its
velocity, acceleration, etc...)
"""

Expand All @@ -24,7 +24,7 @@ def __init__(
name: str,
data: dict[str, torch.Tensor],
model_input_map: dict[str, str] = dict(),
updater: VariableUpdater = None,
updater: _VariableUpdater = None,
device: torch.device = None,
dtype: torch.dtype = None,
) -> None:
Expand All @@ -35,15 +35,17 @@ def __init__(
name for variable
data : dict[str, torch.Tensor]
dictionary of tracked data in the form `value_name: value`
updater : VariableUpdater
object which will update the data of the Variable
over the course of the MD simulation
model_input_map : dict[str, str], optional
dictionary of correspondences between data tracked by Variable
and inputs to the HIP-NN model in the form
`hipnn-db_name: variable-data-key`, by default dict()
device : Union[str, torch.device], optional
updater : _VariableUpdater, optional
object which will update the data of the Variable
over the course of the MD simulation, by default None
device : torch.device, optional
device on which to keep data, by default None
dtype : torch.dtype, optional
dtype for float type data, by default None
"""
self.name = name
self.data = data
Expand Down Expand Up @@ -137,10 +139,10 @@ def _(self, arg: torch.dtype):
self.dtype = arg


class VariableUpdater:
class _VariableUpdater:
"""
Parent class for algorithms to make updates to the data of a Variable during
each step on an MD simulation.
Parent class for algorithms that make updates to the data of a Variable during
each step of an MD simulation.
Subclasses should redefine __init__, pre_step, post_step, and
required_variable_data as needed. The inputs to pre_step and post_step
Expand Down Expand Up @@ -176,7 +178,7 @@ def pre_step(self, dt):
dt : float
timestep
"""
raise NotImplementedError("All subclasses must implement this method.")
pass

def post_step(self, dt, model_outputs):
"""Updates to variables performed during each step of MD simulation
Expand All @@ -189,10 +191,10 @@ def post_step(self, dt, model_outputs):
model_outputs : dict
dictionary of HIPNN model outputs
"""
raise NotImplementedError("All subclasses must implement this method.")
pass


class NullUpdater(VariableUpdater):
class NullUpdater(_VariableUpdater):
"""
Makes no change to the variable data at each step of MD.
"""
Expand All @@ -203,7 +205,7 @@ def pre_step(self, dt):
def post_step(self, dt, model_outputs):
pass

class VelocityVerlet(VariableUpdater):
class VelocityVerlet(_VariableUpdater):
"""
Implements the Velocity Verlet algorithm
"""
Expand All @@ -212,15 +214,15 @@ class VelocityVerlet(VariableUpdater):

def __init__(
self,
force_key: str,
force_db_name: str,
units_force: float = ase.units.eV,
units_acc: float = ase.units.Ang / (1.0**2),
):
"""
Parameters
----------
force_key : str
key which will correspond to the force on the modified Variable
force_db_name : str
key which will correspond to the force on the corresponding Variable
in the HIPNN model output dictionary
units_force : float, optional
amount of eV equal to one in the units used for force output
Expand All @@ -231,7 +233,7 @@ def __init__(
amount of Ang/fs^2 equal to one in the units used for acceleration
in the corresponding Variable, by default units.Ang/(1.0 ** 2) = 1
"""
self.force_key = force_key
self.force_key = force_db_name
self.force_factor = units_force / units_acc

def pre_step(self, dt):
Expand Down Expand Up @@ -271,7 +273,7 @@ def post_step(self, dt, model_outputs):
self.variable.data["velocity"] = self.variable.data["velocity"] + 0.5 * dt * self.variable.data["acceleration"]


class LangevinDynamics(VariableUpdater):
class LangevinDynamics(_VariableUpdater):
"""
Implements the Langevin algorithm
"""
Expand All @@ -280,7 +282,7 @@ class LangevinDynamics(VariableUpdater):

def __init__(
self,
force_key: str,
force_db_name: str,
temperature: float,
frix: float,
units_force=ase.units.eV,
Expand All @@ -290,8 +292,8 @@ def __init__(
"""
Parameters
----------
force_key : str
key which will correspond to the force on the modified Variable
force_db_name : str
key which will correspond to the force on the corresponding Variable
in the HIPNN model output dictionary
temperature : float
temperature for Langevin algorithm
Expand All @@ -309,7 +311,7 @@ def __init__(
used to set seed for reproducibility, by default None
"""

self.force_key = force_key
self.force_key = force_db_name
self.force_factor = units_force / units_acc
self.temperature = temperature
self.frix = frix
Expand Down Expand Up @@ -381,6 +383,10 @@ def __init__(
list of Variable objects which will be tracked during simulation
model : Predictor
HIPNN Predictor
device : torch.device, optional
device to move variables and model to, by default None
dtype : torch.dtype, optional
dtype to convert all float type variable data and model parameters to, by default None
"""

self.variables = variables
Expand All @@ -400,7 +406,7 @@ def variables(self, variables):
variables = [variables]
for variable in variables:
if variable.updater is None:
raise ValueError(f"Variable with name {variable.name} does not have a VariableUpdater set.")
raise ValueError(f"Variable with name {variable.name} does not have a _VariableUpdater set.")

variable_names = [variable.name for variable in variables]
if len(variable_names) != len(set(variable_names)):
Expand Down Expand Up @@ -493,7 +499,6 @@ def _step(
return model_outputs

def _update_data(self, model_outputs: dict):

for variable in self.variables:
for key, value in variable.data.items():
try:
Expand All @@ -505,6 +510,7 @@ def _update_data(self, model_outputs: dict):
self._data[f"output_{key}"].append(value.cpu().detach()[0])
except KeyError:
self._data[f"output_{key}"] = [value.cpu().detach()[0]]


def run(self, dt: float, n_steps: int, record_every: int = None):
"""
Expand All @@ -527,7 +533,7 @@ def run(self, dt: float, n_steps: int, record_every: int = None):

def get_data(self):
"""Returns a dictionary of the recorded data"""
return {key: np.array(value) for key, value in self._data.items()}
return {key: torch.stack(value) for key, value in self._data.items()}

def reset_data(self):
"""Clear all recorded data"""
Expand Down

0 comments on commit 4638b00

Please sign in to comment.