energy initialization (#66)
* adapt energy initialization function to work when there is no E0 in model, rename said function for clarity

* fix mistake in documentation
shinkle-lanl committed Apr 10, 2024
1 parent ba363b3 commit 4d7add5
Showing 13 changed files with 55 additions and 47 deletions.
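
The documentation page and the example scripts below all change in the same mechanical way: the pretraining helper is now imported and called as hierarchical_energy_initialization instead of set_e0_values, with unchanged arguments. A minimal migration sketch, where henergy and database stand for whatever energy node and database object a script already builds:

    # Old spelling (kept as a thin deprecation wrapper, see hippynn/pretraining.py below):
    # from hippynn.pretraining import set_e0_values
    # set_e0_values(henergy, database, trainable_after=False)

    # New spelling introduced by this commit:
    from hippynn.pretraining import hierarchical_energy_initialization
    hierarchical_energy_initialization(henergy, database, trainable_after=False)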
4 changes: 2 additions & 2 deletions docs/source/examples/minimal_workflow.rst
@@ -127,8 +127,8 @@ Now we'll load a database::

Now that we have a database and a model, we can fit the non-interacting energies using the training set in the database::

-from hippynn.pretraining import set_e0_values
-set_e0_values(henergy,database,trainable_after=False)
+from hippynn.pretraining import hierarchical_energy_initialization
+hierarchical_energy_initialization(henergy,database,trainable_after=False)

We're almost there. We specify the training procedure with ``SetupParams``. We need to have

4 changes: 2 additions & 2 deletions examples/InPSNAPExample.py
@@ -138,8 +138,8 @@

# Now that we have a database and a model, we can
# Fit the non-interacting energies by examining the database.
-from hippynn.pretraining import set_e0_values
-set_e0_values(henergy, database, peratom=True, energy_name="EnergyPerAtom", decay_factor=1e-2)
+from hippynn.pretraining import hierarchical_energy_initialization
+hierarchical_energy_initialization(henergy, database, peratom=True, energy_name="EnergyPerAtom", decay_factor=1e-2)
# Freeze sensitivity layers
for sense_layer in network.torch_module.sensitivity_layers:
    sense_layer.mu.requires_grad_(False)
4 changes: 2 additions & 2 deletions examples/QM7_example.py
@@ -139,9 +139,9 @@
# Now that we have a database and a model, we can
# Fit the non-interacting energies by examining the database.

-from hippynn.pretraining import set_e0_values
+from hippynn.pretraining import hierarchical_energy_initialization

-set_e0_values(henergy, database, trainable_after=False)
+hierarchical_energy_initialization(henergy, database, trainable_after=False)

min_epochs = 50
max_epochs = 800
4 changes: 2 additions & 2 deletions examples/TaSNAPExample.py
@@ -147,9 +147,9 @@ def quantity_losses(quantity):

# Now that we have a database and a model, we can
# Fit the non-interacting energies by examining the database.
-from hippynn.pretraining import set_e0_values
+from hippynn.pretraining import hierarchical_energy_initialization

-set_e0_values(henergy, database, peratom=True, energy_name="EnergyPerAtom", decay_factor=1e-2)
+hierarchical_energy_initialization(henergy, database, peratom=True, energy_name="EnergyPerAtom", decay_factor=1e-2)
# Freeze sensitivity layers
for sense_layer in network.torch_module.sensitivity_layers:
    sense_layer.mu.requires_grad_(False)
4 changes: 2 additions & 2 deletions examples/allegro_ag_example.py
@@ -116,8 +116,8 @@ def fit_model(training_modules,database):

model, loss_module, model_evaluator = training_modules

-from hippynn.pretraining import set_e0_values
-set_e0_values(henergy, database, peratom=True, energy_name="energy_per_atom", decay_factor=1e-2)
+from hippynn.pretraining import hierarchical_energy_initialization
+hierarchical_energy_initialization(henergy, database, peratom=True, energy_name="energy_per_atom", decay_factor=1e-2)

from hippynn.experiment.controllers import RaiseBatchSizeOnPlateau, PatienceController

4 changes: 2 additions & 2 deletions examples/ani1x_training.py
@@ -215,9 +215,9 @@ def main(args):
seed=args.seed,
anidata_location=args.anidata_location)

-from hippynn.pretraining import set_e0_values
+from hippynn.pretraining import hierarchical_energy_initialization

-set_e0_values(henergy, database, trainable_after=False)
+hierarchical_energy_initialization(henergy, database, trainable_after=False)

setup_params = setup_experiment(training_modules,
device=args.gpu,
4 changes: 2 additions & 2 deletions examples/ani_aluminum_example.py
@@ -179,9 +179,9 @@

# Now that we have a database and a model, we can
# Fit the non-interacting energies by examining the database.
-from hippynn.pretraining import set_e0_values
+from hippynn.pretraining import hierarchical_energy_initialization

-set_e0_values(henergy, database, peratom=True, energy_name="energyperatom", decay_factor=1e-2)
+hierarchical_energy_initialization(henergy, database, peratom=True, energy_name="energyperatom", decay_factor=1e-2)

from hippynn.experiment.controllers import RaiseBatchSizeOnPlateau, PatienceController

4 changes: 2 additions & 2 deletions examples/ani_aluminum_example_multilayer.py
@@ -179,9 +179,9 @@

# Now that we have a database and a model, we can
# Fit the non-interacting energies by examining the database.
-from hippynn.pretraining import set_e0_values
+from hippynn.pretraining import hierarchical_energy_initialization

-set_e0_values(henergy, database, peratom=True, energy_name="energyperatom", decay_factor=1e-2)
+hierarchical_energy_initialization(henergy, database, peratom=True, energy_name="energyperatom", decay_factor=1e-2)

from hippynn.experiment.controllers import RaiseBatchSizeOnPlateau, PatienceController

4 changes: 2 additions & 2 deletions examples/barebones.py
@@ -72,9 +72,9 @@
# Now that we have a database and a model, we can
# Fit the non-interacting energies by examining the database.
# This tends to stabilize training a lot.
-from hippynn.pretraining import set_e0_values
+from hippynn.pretraining import hierarchical_energy_initialization

-set_e0_values(henergy, database, trainable_after=False)
+hierarchical_energy_initialization(henergy, database, trainable_after=False)

# Parameters describing the training procedure.
from hippynn.experiment import setup_and_train
4 changes: 2 additions & 2 deletions examples/pyseqm/test.py
@@ -230,8 +230,8 @@
del database.splits["ignore"]
database.make_trainvalidtest_split(test_size=0.1, valid_size=0.1)

-# from hippynn.pretraining import set_e0_values
-# set_e0_values(henergy,database,energy_name="T_transpose",trainable_after=False)
+# from hippynn.pretraining import hierarchical_energy_initialization
+# hierarchical_energy_initialization(henergy,database,energy_name="T_transpose",trainable_after=False)

init_lr = 1e-5
optimizer = torch.optim.Adam(training_modules.model.parameters(), lr=init_lr)
4 changes: 2 additions & 2 deletions examples/pyseqm/test_case2.py
@@ -224,8 +224,8 @@

database = DirectoryDatabase(**database_params)

-# from hippynn.pretraining import set_e0_values
-# set_e0_values(henergy,database,energy_name="T_transpose",trainable_after=False)
+# from hippynn.pretraining import hierarchical_energy_initialization
+# hierarchical_energy_initialization(henergy,database,energy_name="T_transpose",trainable_after=False)

init_lr = 1e-5
optimizer = torch.optim.Adam(training_modules.model.parameters(), lr=init_lr)
4 changes: 2 additions & 2 deletions examples/singlet_triplet_model.py
@@ -184,9 +184,9 @@ def maghist(vec_prediction):

# Now that we have a database and a model, we can
# Fit the non-interacting energies by examining the database.
-from hippynn.pretraining import set_e0_values
+from hippynn.pretraining import hierarchical_energy_initialization

-set_e0_values(henergy, database, energy_name="singlet_T", trainable_after=True)
+hierarchical_energy_initialization(henergy, database, energy_name="singlet_T", trainable_after=True)

patience = 10
batch_size = 512
54 changes: 31 additions & 23 deletions hippynn/pretraining.py
@@ -2,6 +2,7 @@
Things to do before training, i.e. initialization of network and diagnostics.
"""

+import warnings
import numpy as np
import torch

@@ -15,9 +16,9 @@
 from .networks.hipnn import compute_hipnn_e0


-def set_e0_values(
+def hierarchical_energy_initialization(
     energy_module,
-    database,
+    database=None,
     trainable_after=False,
     decay_factor=1e-2,
     encoder=None,
@@ -29,12 +30,13 @@ def set_e0_values(
     Computes values for the non-interacting energy using the training data.
     :param energy_module: HEnergyNode or torch module for energy prediction
-    :param database: InterfaceDB object to get training data
-    :param trainable_after: Determines if it should change .requires_grad attribute for the E0 parameters.
+    :param database: InterfaceDB object to get training data, required if model contains E0 term
+    :param trainable_after: Determines if it should change .requires_grad attribute for the E0 parameters
     :param decay_factor: change initialized weights of further energy layers by ``df**N`` for layer N
     :param network_module: network for running the species encoding. Can be auto-identified from energy node
     :param encoder: species encoder, can be auto-identified from energy node
     :param energy_name: name for the energy variable, can be auto-identified from energy node
     :param species_name: name for the species variable, can be auto-identified from energy node
+    :param peratom:
     :return: None
     """

@@ -51,31 +53,37 @@ def set_e0_values(
     if isinstance(encoder, _BaseNode):
         encoder = encoder.torch_module

-    train_data = database.splits["train"]
-
-    z_vals = train_data[species_name]
-    t_vals = train_data[energy_name]
-
-    encoder.to(t_vals.device)
-    eovals = compute_hipnn_e0(encoder, z_vals, t_vals, peratom=peratom)
-    eo_layer = energy_module.layers[0]
-
-    if not eo_layer.weight.data.shape[-1] == eovals.shape[-1]:
-        raise NotImplementedError("The function set_eo_values does not currently work with custom InputNodes.")
-
-    eo_layer.weight.data = eovals.reshape(1,-1)
-    print("Computed E0 energies:", eovals)
-    eo_layer.weight.data = eovals.expand_as(eo_layer.weight.data)
-    eo_layer.weight.requires_grad_(trainable_after)
+    # If model has E0 term, set its initial value using the database provided
+    if not energy_module.first_is_interacting:
+        if database is None:
+            raise ValueError("Database must be provided if model includes E0 energy term.")
+
+        train_data = database.splits["train"]
+
+        z_vals = train_data[species_name]
+        t_vals = train_data[energy_name]
+
+        encoder.to(t_vals.device)
+        eovals = compute_hipnn_e0(encoder, z_vals, t_vals, peratom=peratom)
+        eo_layer = energy_module.layers[0]
+
+        if not eo_layer.weight.data.shape[-1] == eovals.shape[-1]:
+            raise ValueError("The shape of the computed E0 values does not match the shape expected by the model.")
+
+        eo_layer.weight.data = eovals.reshape(1,-1)
+        print("Computed E0 energies:", eovals)
+        eo_layer.weight.data = eovals.expand_as(eo_layer.weight.data)
+        eo_layer.weight.requires_grad_(trainable_after)

+    # Decay layers E1, E2, etc... according to decay_factor
     for layer in energy_module.layers[1:]:
         layer.weight.data *= decay_factor
         layer.bias.data *= decay_factor
         decay_factor *= decay_factor

+def set_e0_values(*args, **kwargs):
+    warnings.warn("The function set_e0_values is depreciated. Please use the hierarchical_energy_initialization function instead.")
+    return hierarchical_energy_initialization(*args, **kwargs)
+
 def _setup_min_dist_graph(
     species_name,
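
Taken together, the rewritten body separates two concerns. The E0 (non-interacting) energies are computed from the training split and written into the first energy layer only when the model actually has an E0 term (energy_module.first_is_interacting is False), which is why database is now optional; the scaling of the later energy layers by decay_factor is applied in every case. A rough usage sketch of the two call patterns, with henergy and database again standing in for a real script's objects:

    from hippynn.pretraining import hierarchical_energy_initialization

    # Energy node with an E0 term: a database is required, since the E0
    # values are fit from database.splits["train"].
    hierarchical_energy_initialization(henergy, database, trainable_after=False)

    # Energy node without an E0 term (first_is_interacting is True): the
    # database can now be omitted; only the decay of later layers is applied.
    hierarchical_energy_initialization(henergy, decay_factor=1e-2)

Because decay_factor is squared after each layer (decay_factor *= decay_factor), the default value of 1e-2 scales the successive layers by 1e-2, 1e-4, 1e-8, and so on.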
