From 4e84c36fc0b1502e829537c32492254cfeb80ef4 Mon Sep 17 00:00:00 2001 From: Sakib Matin <83463357+sakibmatin@users.noreply.github.com> Date: Tue, 3 Sep 2024 09:35:02 -0600 Subject: [PATCH 1/6] Check cuda device capability for triton kernel initialization. (#93) * check device compute for triton import * avoid nested try blocks * Clean up. * typo fixed --- hippynn/custom_kernels/__init__.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/hippynn/custom_kernels/__init__.py b/hippynn/custom_kernels/__init__.py index cd7ca9d3..8ef8953b 100644 --- a/hippynn/custom_kernels/__init__.py +++ b/hippynn/custom_kernels/__init__.py @@ -36,8 +36,18 @@ try: import triton + import torch + device_capability = torch.cuda.get_device_capability() + if device_capability[0] > 6: + CUSTOM_KERNELS_AVAILABLE.append("triton") + else: + warnings.warn( + f"Triton found but not supported by GPU's compute capability: {device_capability}" + ) +except ImportError: + pass - CUSTOM_KERNELS_AVAILABLE.append("triton") + except ImportError: pass @@ -76,7 +86,7 @@ def _check_cupy(): if not cupy.cuda.is_available(): if torch.cuda.is_available(): warnings.warn("cupy.cuda.is_available() returned False: Custom kernels will fail on GPU tensors.") - + def set_custom_kernels(active: Union[bool, str] = True): """ @@ -113,7 +123,6 @@ def set_custom_kernels(active: Union[bool, str] = True): return # Select custom kernel implementation - if not CUSTOM_KERNELS_AVAILABLE: raise RuntimeError("Numba was not found. Custom kernels are not available.") From 0004728799fa1048d87d785ea06cce59c609f041 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers <56895592+lubbersnick@users.noreply.github.com> Date: Fri, 6 Sep 2024 18:51:26 -0600 Subject: [PATCH 2/6] Add pytorch lightning trainer (#99) * initial attempt of lightning training interface * fix train_step and remove print * fix batch order for end validation epoch * fix types * fix raisebatchsize for lightning * remember to detach tensors * add valid tag to lr scheduler * add loss printing and controller * add dataloader args for additional configuration * refactor slightly and fix type errors * add extra dataloader args to test script * closer to connecting controller to lightning module * guard print statements * prevent double-updating of schedulers * add sanity checking guard * fix printing in sanity check * get batch size changes working with pytorch lightning * make sure custom kernels don't automatically trigger cuda context on device 0 * adding saving of modules (very necessary!) 
* make lightning trainer not have to serialize constantly * add coalescing custom kernel call for hip-nn-ts (l=2) * add coalescing custom kernel call to hip-nn-ts, l=1 * formating and debug print * update packages in docs * make lightning import optional * make metric tracker only seek better metrics on validation * Make controller and metric tracker see metrics reduced across nodes * update lightning test script * update docs and requirements, remove extraneous code * good old fashioned formatting --------- Co-authored-by: Nicholas Lubbers --- conda_requirements.txt | 1 + docs/source/conf.py | 4 +- docs/source/installation.rst | 17 +- docs/source/user_guide/settings.rst | 2 +- examples/barebones_lightning.py | 102 +++++ hippynn/_settings_setup.py | 30 +- hippynn/custom_kernels/tensor_wrapper.py | 8 +- hippynn/custom_kernels/test_env_numba.py | 7 + hippynn/databases/__init__.py | 9 + hippynn/databases/database.py | 38 +- hippynn/experiment/__init__.py | 9 +- hippynn/experiment/controllers.py | 56 ++- hippynn/experiment/lightning_trainer.py | 369 ++++++++++++++++++ hippynn/experiment/metric_tracker.py | 31 +- hippynn/experiment/routines.py | 5 +- hippynn/experiment/serialization.py | 14 +- hippynn/graphs/gops.py | 3 +- .../interfaces/ase_interface/ase_database.py | 12 +- hippynn/layers/hiplayers.py | 62 ++- hippynn/pretraining.py | 2 +- hippynn/tools.py | 19 +- setup.py | 1 + tests/lightning_QM7_test.py | 219 +++++++++++ 23 files changed, 912 insertions(+), 108 deletions(-) create mode 100644 examples/barebones_lightning.py create mode 100644 hippynn/experiment/lightning_trainer.py create mode 100644 tests/lightning_QM7_test.py diff --git a/conda_requirements.txt b/conda_requirements.txt index 590e5ca3..f8f1f391 100644 --- a/conda_requirements.txt +++ b/conda_requirements.txt @@ -8,3 +8,4 @@ ase h5py tqdm python-graphviz +lightning \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 9a71efed..a47dfe54 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,7 +19,7 @@ project = "hippynn" copyright = "2019, Los Alamos National Laboratory" -author = "Nicholas Lubbers" +author = "Nicholas Lubbers et al" # The full version, including alpha/beta/rc tags import hippynn @@ -47,7 +47,7 @@ } # The following are highly optional, so we mock them for doc purposes. -autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba"] +autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba", "triton", "pytorch_lightning"] # -- Options for HTML output ------------------------------------------------- diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 54384e44..4064fea9 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -10,16 +10,18 @@ Requirements: * Python_ >= 3.9 * pytorch_ >= 1.9 * numpy_ + Optional Dependencies: * triton_ (recommended, for improved GPU performance) * numba_ (recommended for improved CPU performance) - * cupy_ (Alternative for accelerating GPU performance) - * ASE_ (for usage with ase) + * cupy_ (alternative for accelerating GPU performance) + * ASE_ (for usage with ase and other misc. 
features) * matplotlib_ (for plotting) * tqdm_ (for progress bars) - * graphviz_ (for viewing model graphs as figures) + * graphviz_ (for visualizing model graphs) * h5py_ (for loading ani-h5 datasets) * pyanitools_ (for loading ani-h5 datasets) + * pytorch-lightning_ (for distributed training) Interfacing codes: * ASE_ @@ -40,7 +42,7 @@ Interfacing codes: .. _ASE: https://wiki.fysik.dtu.dk/ase/ .. _LAMMPS: https://www.lammps.org/ .. _PYSEQM: https://github.com/lanl/PYSEQM - +.. _pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning Installation Instructions ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -67,9 +69,6 @@ Clone the hippynn_ repository and navigate into it, e.g.:: .. _hippynn: https://github.com/lanl/hippynn/ -.. note:: - If you wish to do a cpu-only install, you may need to comment - out ``cupy`` from the conda_requirements.txt file. Dependencies using conda ........................ @@ -78,6 +77,10 @@ Install dependencies from conda using recommended channels:: $ conda install -c pytorch -c conda-forge --file conda_requirements.txt +.. note:: + If you wish to do a cpu-only install, you may need to comment + out ``cupy`` from the conda_requirements.txt file. + Dependencies using pip ....................... diff --git a/docs/source/user_guide/settings.rst b/docs/source/user_guide/settings.rst index d8657de4..c6764206 100644 --- a/docs/source/user_guide/settings.rst +++ b/docs/source/user_guide/settings.rst @@ -31,7 +31,7 @@ The following settings are available: - Dynamic * - PROGRESS - Progress bars function during training, evaluation, and prediction - - tqdm, none + - tqdm, none, or floating point string specifying default update rate in seconds (default 1). - tqdm - Yes, but assign this to a generator-wrapper such as ``tqdm.tqdm``, or with a python ``None`` to disable. The wrapper must accept ``tqdm`` arguments, although it technically doesn't have to do anything with them. * - DEFAULT_PLOT_FILETYPE diff --git a/examples/barebones_lightning.py b/examples/barebones_lightning.py new file mode 100644 index 00000000..4469d7ed --- /dev/null +++ b/examples/barebones_lightning.py @@ -0,0 +1,102 @@ +''' +To obtain the data files needed for this example, use the script process_QM7_data.py, +also located in this folder. The script contains further instructions for use. +''' + +import torch + +# Setup pytorch things +torch.set_default_dtype(torch.float32) + +import hippynn + +netname = "TEST_BAREBONES_LIGHTNING_SCRIPT" + +# Hyperparameters for the network +# These are set deliberately small so that you can easily run the example on a laptop or similar. +network_params = { + "possible_species": [0, 1, 6, 7, 8, 16], # Z values of the elements in QM7 + "n_features": 20, # Number of neurons at each layer + "n_sensitivities": 20, # Number of sensitivity functions in an interaction layer + "dist_soft_min": 1.6, # qm7 is in Bohr! 
+ "dist_soft_max": 10.0, + "dist_hard_max": 12.5, + "n_interaction_layers": 2, # Number of interaction blocks + "n_atom_layers": 3, # Number of atom layers in an interaction block +} + +# Define a model +from hippynn.graphs import inputs, networks, targets, physics + +species = inputs.SpeciesNode(db_name="Z") +positions = inputs.PositionsNode(db_name="R") + +network = networks.Hipnn("hipnn_model", (species, positions), module_kwargs=network_params) +henergy = targets.HEnergyNode("HEnergy", network, db_name="T") +# hierarchicality = henergy.hierarchicality + +# define loss quantities +from hippynn.graphs import loss + +mse_energy = loss.MSELoss.of_node(henergy) +mae_energy = loss.MAELoss.of_node(henergy) +rmse_energy = mse_energy ** (1 / 2) + +# Validation losses are what we check on the data between epochs -- we can only train to +# a single loss, but we can check other metrics too to better understand how the model is training. +# There will also be plots of these things over time when training completes. +validation_losses = { + "RMSE": rmse_energy, + "MAE": mae_energy, + "MSE": mse_energy, +} + +# This piece of code glues the stuff together as a pytorch model, +# dropping things that are irrelevant for the losses defined. +training_modules, db_info = hippynn.experiment.assemble_for_training(mse_energy, validation_losses) + +# Go to a directory for the model. +# hippynn will save training files in the current working directory. +with hippynn.tools.active_directory(netname): + # Log the output of python to `training_log.txt` + with hippynn.tools.log_terminal("training_log.txt", "wt"): + database = hippynn.databases.DirectoryDatabase( + name="data-qm7", # Prefix for arrays in the directory + directory="../../../datasets/qm7_processed", + test_size=0.1, # Fraction or number of samples to test on + valid_size=0.1, # Fraction or number of samples to validate on + seed=2001, # Random seed for splitting data + **db_info, # Adds the inputs and targets db_names from the model as things to load + dataloader_kwargs=dict(persistent_workers=True,multiprocessing_context='fork'), + num_workers=2, + ) + + # Now that we have a database and a model, we can + # Fit the non-interacting energies by examining the database. + # This tends to stabilize training a lot. + from hippynn.pretraining import hierarchical_energy_initialization + + hierarchical_energy_initialization(henergy, database, trainable_after=False) + + # Parameters describing the training procedure. + from hippynn.experiment import setup_and_train + + experiment_params = hippynn.experiment.SetupParams( + stopping_key="MSE", # The name in the validation_losses dictionary. + batch_size=12, + optimizer=torch.optim.Adam, + max_epochs=100, + learning_rate=0.001, + ) + # setup_and_train( + # training_modules=training_modules, + # database=database, + # setup_params=experiment_params, + # ) + from hippynn.experiment import HippynnLightningModule + +# lightning needs to run exactly where the script is located in distributed modes. +lightmod, datamodule = HippynnLightningModule.from_experiment_setup(training_modules, database, experiment_params) +import pytorch_lightning as pl +trainer = pl.Trainer(accelerator='cpu') #'auto' detects MPS which doesn't work. 
+trainer.fit(model=lightmod, datamodule=datamodule) diff --git a/hippynn/_settings_setup.py b/hippynn/_settings_setup.py index 62b872c9..c72eb94f 100644 --- a/hippynn/_settings_setup.py +++ b/hippynn/_settings_setup.py @@ -29,16 +29,22 @@ TQDM_PROGRESS = None if TQDM_PROGRESS is not None: - TQDM_PROGRESS = partial(TQDM_PROGRESS, mininterval=1.0, leave=False) - + DEFAULT_PROGRESS = partial(TQDM_PROGRESS, mininterval=1.0, leave=False) +else: + DEFAULT_PROGRESS = None ### Progress handlers - def progress_handler(prog_str): if prog_str == "tqdm": - return TQDM_PROGRESS - if prog_str.lower() == "none": + return DEFAULT_PROGRESS + elif prog_str.lower() == "none": return None + else: + try: + prog_float = float(prog_str) + return partial(TQDM_PROGRESS, mininterval=prog_float, leave=False) + except: + pass warnings.warn(f"Unrecognized progress setting: '{prog_str}'. Setting to none.") @@ -63,7 +69,7 @@ def kernel_handler(kernel_string): # keys: defaults, types, and handlers default_settings = { - "PROGRESS": (TQDM_PROGRESS, progress_handler), + "PROGRESS": (DEFAULT_PROGRESS, progress_handler), "DEFAULT_PLOT_FILETYPE": (".pdf", str), "TRANSPARENT_PLOT": (False, strtobool), "DEBUG_LOSS_BROADCAST": (False, strtobool), @@ -85,11 +91,16 @@ def kernel_handler(kernel_string): config_sources = {} # Dictionary of configuration variable sources mapping to dictionary of configuration. # We add to this dictionary in order of application +SECTION_NAME = "GLOBALS" + rc_name = os.path.expanduser("~/.hippynnrc") if os.path.exists(rc_name) and os.path.isfile(rc_name): config = configparser.ConfigParser(inline_comment_prefixes="#") config.read(rc_name) - config_sources["~/.hippynnrc"] = config["GLOBALS"] + if SECTION_NAME not in config: + warnings.warn(f"Config file {rc_name} does not contain a {SECTION_NAME} section and will be ignored!") + else: + config_sources["~/.hippynnrc"] = config[SECTION_NAME] SETTING_PREFIX = "HIPPYNN_" hippynn_environment_variables = { @@ -103,7 +114,10 @@ def kernel_handler(kernel_string): if os.path.exists(local_rc_fname) and os.path.isfile(local_rc_fname): local_config = configparser.ConfigParser() local_config.read(local_rc_fname) - config_sources[LOCAL_RC_FILE_KEY] = local_config["GLOBALS"] + if SECTION_NAME not in local_config: + warnings.warn(f"Config file {local_rc_fname} does not contain a {SECTION_NAME} section and will be ignored!") + else: + config_sources[LOCAL_RC_FILE_KEY] = local_config[SECTION_NAME] else: warnings.warn(f"Local configuration file {local_rc_fname} not found.") diff --git a/hippynn/custom_kernels/tensor_wrapper.py b/hippynn/custom_kernels/tensor_wrapper.py index ade8ddcf..6323b61e 100644 --- a/hippynn/custom_kernels/tensor_wrapper.py +++ b/hippynn/custom_kernels/tensor_wrapper.py @@ -38,8 +38,8 @@ def _numba_gpu_not_found(*args, **kwargs): class NumbaCompatibleTensorFunction: def __init__(self): if numba.cuda.is_available(): - self.kernel64 = self.make_kernel(numba.float64) - self.kernel32 = self.make_kernel(numba.float32) + self.kernel64 = None + self.kernel32 = None else: self.kernel64 = _numba_gpu_not_found self.kernel32 = _numba_gpu_not_found @@ -59,8 +59,12 @@ def __call__(self, *args, **kwargs): with numba.cuda.gpus[dev.index]: numba_args = batch_convert_torch_to_numba(*args) if dtype == torch.float64: + if self.kernel64 is None: + self.kernel64 = self.make_kernel(numba.float64) self.kernel64[launch_bounds](*numba_args) elif dtype == torch.float32: + if self.kernel32 is None: + self.kernel32 = self.make_kernel(numba.float32) 
self.kernel32[launch_bounds](*numba_args) else: raise ValueError("Bad dtype: {}".format(dtype)) diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env_numba.py index d9a117c1..616a2eb8 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env_numba.py @@ -122,6 +122,7 @@ def get_simulated_data(n_molecules, n_atoms, atom_prob, n_features, n_nu, printi TEST_LARGE_PARAMS = dict(n_molecules=1000, n_atoms=30, atom_prob=0.7, n_features=80, n_nu=20) TEST_MEGA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=100) TEST_ULTRA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=320) +TEST_GIGA_PARAMS = dict(n_molecules=32, n_atoms=30, atom_prob=0.7, n_features=512, n_nu=320) # reference implementation @@ -434,6 +435,12 @@ def main(env_impl, sense_impl, feat_impl, args=None): if use_verylarge_gpu: if use_ultra: + + print("-" * 80) + print("Giga systems:", TEST_GIGA_PARAMS) + tester.check_speed( + n_repetitions=20, data_size=TEST_GIGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against + ) print("-" * 80) print("Ultra systems:", TEST_ULTRA_PARAMS) tester.check_speed( diff --git a/hippynn/databases/__init__.py b/hippynn/databases/__init__.py index d938fd54..e97ad715 100644 --- a/hippynn/databases/__init__.py +++ b/hippynn/databases/__init__.py @@ -12,16 +12,25 @@ from .database import Database from .ondisk import DirectoryDatabase, NPZDatabase has_ase = False +has_h5 = False + try: import ase has_ase = True + import h5py + has_h5 = True except ImportError: pass if has_ase: from ..interfaces.ase_interface import AseDatabase + if has_h5: + from .h5_pyanitools import PyAniFileDB, PyAniDirectoryDB all_list = ["Database", "DirectoryDatabase", "NPZDatabase"] + if has_ase: all_list += ["AseDatabase"] + if has_h5: + all_list += ["PyAniFileDB", "PyAniDirectoryDB"] __all__ = all_list diff --git a/hippynn/databases/database.py b/hippynn/databases/database.py index 19acbb52..fa503763 100644 --- a/hippynn/databases/database.py +++ b/hippynn/databases/database.py @@ -1,6 +1,8 @@ """ Base database functionality from dictionary of numpy arrays """ + +from typing import Union import warnings import numpy as np import torch @@ -20,17 +22,18 @@ class Database: def __init__( self, - arr_dict, - inputs, - targets, - seed, - test_size=None, - valid_size=None, - num_workers=0, - pin_memory=True, - allow_unfound=False, - auto_split=False, - device=None, + arr_dict: dict[str,torch.Tensor], + inputs: list[str], + targets: list[str], + seed: [int,np.random.RandomState,tuple], + test_size: Union[float,int]=None, + valid_size: Union[float,int]=None, + num_workers: int=0, + pin_memory: bool=True, + allow_unfound:bool =False, + auto_split:bool =False, + device: torch.device=None, + dataloader_kwargs:dict[str,object]=None, quiet=False, ): """ @@ -47,6 +50,9 @@ def __init__( :param allow_unfound: If true, skip checking if the needed inputs and targets are found. This allows setting inputs=None and/or targets=None. :param auto_split: If true, look for keys like "split_*" to make initial splits from. See write_npz() method. + :param device: if set, move the dataset to this device after splitting. + :param dataloader_kwargs: dictionary, passed to pytorch dataloaders in addition to num_workers, pin_memory. + Refer to pytorch documentation for details. :param quiet: If True, print little or nothing while loading. 
""" @@ -123,6 +129,8 @@ def __init__( else: self.send_to_device(device) + self.dataloader_kwargs = dataloader_kwargs.copy() if dataloader_kwargs else {} + def __len__(self): return arrdict_len(self.arr_dict) @@ -425,6 +433,7 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample shuffle=shuffle, pin_memory=self.pin_memory, num_workers=self.num_workers, + **self.dataloader_kwargs, ) return generator @@ -514,7 +523,7 @@ def write_h5(self, split=None, h5path=None, species_key='species', overwrite=Fal return write_h5_function(self, split=split, file=h5path, species_key=species_key, overwrite=overwrite) - def write_npz(self, file: str, record_split_masks: bool = True, overwrite: bool = False, split_prefix=None, return_only=False): + def write_npz(self, file: str, record_split_masks: bool = True, compressed:bool =True, overwrite: bool = False, split_prefix=None, return_only=False): """ :param file: str, Path, or file object compatible with np.save :param record_split_masks: @@ -561,7 +570,10 @@ def write_npz(self, file: str, record_split_masks: bool = True, overwrite: bool if file.exists() and not overwrite: raise FileExistsError(f"File exists: {file}") - np.savez_compressed(file, **arr_dict) + if compressed: + np.savez_compressed(file, **arr_dict) + else: + np.savez(file, **arr_dict) return arr_dict diff --git a/hippynn/experiment/__init__.py b/hippynn/experiment/__init__.py index e31cb597..3a222e9b 100644 --- a/hippynn/experiment/__init__.py +++ b/hippynn/experiment/__init__.py @@ -13,4 +13,11 @@ from .assembly import assemble_for_training from .routines import setup_and_train, setup_training, train_model, test_model, SetupParams -__all__ = ["assemble_for_training", "setup_and_train", "setup_training", "train_model", "test_model", "SetupParams"] + +__all__ = ["assemble_for_training", "setup_and_train", "setup_training", "train_model", "test_model", "SetupParams",] + +try: + from .lightning_trainer import HippynnLightningModule + __all__ += ["HippynnLightningModule"] +except ImportError: + pass diff --git a/hippynn/experiment/controllers.py b/hippynn/experiment/controllers.py index 125e168a..dc88e77f 100644 --- a/hippynn/experiment/controllers.py +++ b/hippynn/experiment/controllers.py @@ -6,7 +6,6 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau - class Controller: """ Class for controlling the training dynamics. 
@@ -51,12 +50,10 @@ def __init__( fraction_train_eval=0.1, quiet=False, ): + super().__init__() self.optimizer = optimizer - self.scheduler = scheduler - self.stopping_key = stopping_key - self.batch_size = batch_size self.eval_batch_size = eval_batch_size or batch_size if max_epochs is None: @@ -85,7 +82,8 @@ def __init__( def state_dict(self): state_dict = {k: getattr(self, k) for k in self._state_vars} - state_dict["optimizer"] = self.optimizer.state_dict() + if self.optimizer is not None: + state_dict["optimizer"] = self.optimizer.state_dict() state_dict["scheduler"] = [sch.state_dict() for sch in self.scheduler_list] return state_dict @@ -94,7 +92,8 @@ def load_state_dict(self, state_dict): for sch, sdict in zip(self.scheduler_list, state_dict["scheduler"]): sch.load_state_dict(sdict) - self.optimizer.load_state_dict(state_dict["optimizer"]) + if self.optimizer is not None: + self.optimizer.load_state_dict(state_dict["optimizer"]) for k in self._state_vars: setattr(self, k, state_dict[k]) @@ -103,7 +102,7 @@ def load_state_dict(self, state_dict): def max_epochs(self): return self._max_epochs - def push_epoch(self, epoch, better_model, metric): + def push_epoch(self, epoch, better_model, metric, _print=print): self.current_epoch += 1 if better_model: @@ -118,8 +117,9 @@ def push_epoch(self, epoch, better_model, metric): sch.step() if not self.quiet: - print("Epochs since last best:", self.boredom) - print("Current max epochs:", self.max_epochs) + _print("Epochs since last best:", self.boredom) + _print("Current max epochs:", self.max_epochs) + return self.current_epoch < self.max_epochs @@ -139,23 +139,27 @@ def __init__(self, *args, termination_patience, **kwargs): self.patience = termination_patience self.last_best = 0 - def push_epoch(self, epoch, better_model, metric): + def push_epoch(self, epoch, better_model, metric, _print=print): if better_model: if self.boredom > 0 and not self.quiet: - print("Patience for training restored.") + _print("Patience for training restored.") self.boredom = 0 self.last_best = epoch - return super().push_epoch(epoch, better_model, metric) + return super().push_epoch(epoch, better_model, metric, _print=_print) @property def max_epochs(self): - return min(self.last_best + self.patience, self._max_epochs) + return min(self.last_best + self.patience + 1, self._max_epochs) -class RaiseBatchSizeOnPlateau: +# Developer note: The inheritance here is only so that pytorch lightning +# readily identifies this as a scheduler. +class RaiseBatchSizeOnPlateau(ReduceLROnPlateau): """ Learning rate scheduler compatible with pytorch schedulers. + Note: The "VERBOSE" Parameter has been deprecated and no longer does anything. + This roughly implements the scheme outlined in the following paper: .. code-block:: none @@ -182,9 +186,20 @@ def __init__( patience=10, threshold=0.0001, threshold_mode="rel", - verbose=True, + verbose=None, # DEPRECATED controller=None, ): + """ + + :param optimizer: + :param max_batch_size: + :param factor: + :param patience: + :param threshold: + :param threshold_mode: + :param verbose: + :param controller: + """ if threshold_mode not in ("abs", "rel"): raise ValueError("Mode must be 'abs' or 'rel'") @@ -195,13 +210,17 @@ def __init__( factor=factor, threshold=threshold, threshold_mode=threshold_mode, - verbose=verbose, ) self.controller = controller self.max_batch_size = max_batch_size self.best_metric = float("inf") self.boredom = 0 self.last_epoch = 0 + warnings.warn("Parameter verbose no longer supported for schedulers. 
It will be ignored.") + + @property + def optimizer(self): + return self.inner.optimizer def set_controller(self, box): self.controller = box @@ -250,12 +269,9 @@ def step(self, metrics): new_batch_size = min(new_batch_size, self.max_batch_size) self.controller.batch_size = new_batch_size self.boredom = 0 - if self.inner.verbose: - print("Raising batch size to", new_batch_size) + if new_batch_size >= self.max_batch_size: self.inner.last_epoch = self.last_epoch - 1 - if self.inner.verbose: - print("Max batch size reached, Lowering learning rate from here.") return diff --git a/hippynn/experiment/lightning_trainer.py b/hippynn/experiment/lightning_trainer.py new file mode 100644 index 00000000..ead8eb57 --- /dev/null +++ b/hippynn/experiment/lightning_trainer.py @@ -0,0 +1,369 @@ +""" +Pytorch Lightning training interface. + +This module is somewhat experimental. Using pytorch lightning +successfully in a distributed context may require understanding +and adjusting the various settings related to parallelism, e.g. +multiprocessing context, torch ddp backend, and how they interact +with your HPC environment. + +Some features of hippynn experiments may not be implemented yet. + - The plotmaker is currently not supported. + +""" +import warnings +import copy +from pathlib import Path + +import torch + +import pytorch_lightning as pl + +from .routines import TrainingModules +from ..databases import Database +from .routines import SetupParams, setup_training +from ..graphs import GraphModule +from .controllers import Controller +from .metric_tracker import MetricTracker +from .step_functions import get_step_function, StandardStep +from ..tools import print_lr +from . import serialization + + +class HippynnLightningModule(pl.LightningModule): + def __init__( + self, + model: GraphModule, + loss: GraphModule, + eval_loss: GraphModule, + eval_names: list[str], + stopping_key: str, + optimizer_list: list[torch.optim.Optimizer], + scheduler_list: list[torch.optim.lr_scheduler], + controller: Controller, + metric_tracker: MetricTracker, + inputs: list[str], + targets: list[str], + n_outputs: int, + *args, + **kwargs, + ): # forwards args and kwargs to where? + super().__init__() + + self.save_hyperparameters(ignore=["loss", "model", "eval_loss", "controller", "optimizer_list", "scheduler_list"]) + + self.model = model + self.loss = loss + self.eval_loss = eval_loss + self.eval_names = eval_names + self.stopping_key = stopping_key + self.controller = controller + self.metric_tracker = metric_tracker + self.optimizer_list = optimizer_list + self.scheduler_list = scheduler_list + self.inputs = inputs + self.targets = targets + self.n_inputs = len(self.inputs) + self.n_targets = len(self.targets) + self.n_outputs = n_outputs + + self.structure_file = None + + self._last_reload_dlene = None # storage for whether batch size should be changed. + + # Storage for predictions across batches for eval mode. + self.eval_step_outputs = [] + self.controller.optimizer = None + + for optimizer in self.optimizer_list: + if not isinstance(step_fn := get_step_function(optimizer), StandardStep): # := + raise NotImplementedError(f"Optimzers with non-standard steps are not yet supported. 
{optimizer,step_fn}") + + if args or kwargs: + raise NotImplementedError("Generic args and kwargs not supported.") + + @classmethod + def from_experiment_setup(cls, training_modules: TrainingModules, database: Database, setup_params: SetupParams, **kwargs): + training_modules, controller, metric_tracker = setup_training(training_modules, setup_params) + return cls.from_train_setup(training_modules, database, controller, metric_tracker, **kwargs) + + @classmethod + def from_train_setup( + cls, + training_modules: TrainingModules, + database: Database, + controller: Controller, + metric_tracker: MetricTracker, + callbacks=None, + batch_callbacks=None, + **kwargs, + ): + + model, loss, evaluator = training_modules + + warnings.warn("PytorchLightning hippynn trainer is still experimental.") + + if evaluator.plot_maker is not None: + warnings.warn("plot_maker is not currently supported in pytorch lightning. The current plot_maker will be ignored.") + + trainer = cls( + model=model, + loss=loss, + eval_loss=evaluator.loss, + eval_names=evaluator.loss_names, + optimizer_list=[controller.optimizer], + scheduler_list=controller.scheduler_list, + stopping_key=controller.stopping_key, + controller=controller, + metric_tracker=metric_tracker, + inputs=database.inputs, + targets=database.targets, + n_outputs=evaluator.n_outputs, + **kwargs, + ) + + # pytorch lightning is now in charge of stepping the scheduler. + controller.scheduler_list = [] + + if callbacks is not None or batch_callbacks is not None: + return NotImplemented("arbitrary callbacks are not yet supported with pytorch lightning.") + + return trainer, HippynnDataModule(database, controller.batch_size) + + def on_save_checkpoint(self, checkpoint) -> None: + + # Note to future developers: + # trainer.log_dir property needs to be called on all ranks! This is weird but important; + # do not move trainer.log_dir inside of a rank zero operation! + # see https://github.com/Lightning-AI/pytorch-lightning/discussions/8321 + # Thank you to https://github.com/semaphore-egg . + log_dir = self.trainer.log_dir + + if not self.structure_file: + # Perform change on all ranks. 
+ sf = serialization.DEFAULT_STRUCTURE_FNAME + self.structure_file = sf + + if self.global_rank == 0 and not self.structure_file: + self.print("creating structure file.") + structure = dict( + model=self.model, + loss=self.loss, + eval_loss=self.eval_loss, + controller=self.controller, + optimizer_list=self.optimizer_list, + scheduler_list=self.scheduler_list, + ) + path: Path = Path(log_dir).joinpath(sf) + self.print("Saving structure file at", path) + torch.save(obj=structure, f=path) + + checkpoint["controller_state"] = self.controller.state_dict() + return + + @classmethod + def load_from_checkpoint(cls, checkpoint_path, map_location=None, structure_file=None, hparams_file=None, strict=True, **kwargs): + + if structure_file is None: + # Assume checkpoint_path is like /version_/checkpoints/.chkpt + # and that experiment file is stored at /version_/experiment_structure.pt + structure_file = Path(checkpoint_path) + structure_file = structure_file.parent.parent + structure_file = structure_file.joinpath(serialization.DEFAULT_STRUCTURE_FNAME) + + structure_args = torch.load(structure_file) + + return super().load_from_checkpoint( + checkpoint_path, map_location=map_location, hparams_file=hparams_file, strict=strict, **structure_args, **kwargs + ) + + def on_load_checkpoint(self, checkpoint) -> None: + cstate = checkpoint.pop("controller_state") + self.controller.load_state_dict(cstate) + return + + def configure_optimizers(self): + + scheduler_list = [] + for s in self.scheduler_list: + config = { + "scheduler": s, + "interval": "epoch", # can be epoch or step + "frequency": 1, # How many intervals should pass between calls to `scheduler.step()`. + "monitor": "valid_" + self.stopping_key, # Metric to monitor for schedulers like `ReduceLROnPlateau` + "strict": True, + "name": "learning_rate", + } + scheduler_list.append(config) + + optimizer_list = self.optimizer_list.copy() + + return optimizer_list, scheduler_list + + def on_train_epoch_start(self): + for optimizer in self.optimizer_list: + print_lr(optimizer, print_=self.print) + self.print("Batch size:", self.trainer.train_dataloader.batch_size) + + def training_step(self, batch, batch_idx): + + batch_inputs = batch[: self.n_inputs] + batch_targets = batch[-self.n_targets :] + + batch_model_outputs = self.model(*batch_inputs) + batch_train_loss = self.loss(*batch_model_outputs, *batch_targets)[0] + + self.log("train_loss", batch_train_loss) + return batch_train_loss + + def _eval_step(self, batch, batch_idx): + + batch_inputs = batch[: self.n_inputs] + batch_targets = batch[-self.n_targets :] + + # It is very, very common to fit to derivatives, e.g. force, in hippynn. Override lightning default. + with torch.autograd.set_grad_enabled(True): + batch_predictions = self.model(*batch_inputs) + + batch_predictions = [bp.detach() for bp in batch_predictions] + + outputs = (batch_predictions, batch_targets) + self.eval_step_outputs.append(outputs) + return batch_predictions + + def validation_step(self, batch, batch_idx): + return self._eval_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self._eval_step(batch, batch_idx) + + def _eval_epoch_end(self, prefix): + + all_batch_predictions, all_batch_targets = zip(*self.eval_step_outputs) + # now 'shape' (n_batch, n_outputs) -> need to transpose. + all_batch_predictions = [[bpred[i] for bpred in all_batch_predictions] for i in range(self.n_outputs)] + # now 'shape' (n_batch, n_targets) -> need to transpose. 
+ all_batch_targets = [[bpred[i] for bpred in all_batch_targets] for i in range(self.n_targets)] + + # now cat each prediction and target across the batch index. + all_predictions = [torch.cat(x, dim=0) if x[0].shape != () else x[0] for x in all_batch_predictions] + all_targets = [torch.cat(x, dim=0) for x in all_batch_targets] + + all_losses = [x.item() for x in self.eval_loss(*all_predictions, *all_targets)] + self.eval_step_outputs.clear() # free memory + + loss_dict = {name: value for name, value in zip(self.eval_names, all_losses)} + + self.log_dict({prefix + k: v for k, v in loss_dict.items()}, sync_dist=True) + + return + + def on_validation_epoch_end(self): + self._eval_epoch_end(prefix="valid_") + return + + def on_test_epoch_end(self): + self._eval_epoch_end(prefix="test_") + return + + def _eval_end(self, prefix, when=None) -> None: + if when is None: + if self.trainer.sanity_checking: + when = "Sanity Check" + else: + when = self.current_epoch + + # Step 1: get metrics reduced from all ranks. + # Copied pattern from pytorch_lightning. + metrics = copy.deepcopy(self.trainer.callback_metrics) + + pre_len = len(prefix) + loss_dict = {k[pre_len:]: v.item() for k, v in metrics.items() if k.startswith(prefix)} + + loss_dict = {prefix[:-1]: loss_dict} # strip underscore from prefix and wrap. + + if self.trainer.sanity_checking: + self.print("Sanity check metric values:") + self.metric_tracker.evaluation_print(loss_dict, _print=self.print) + return + + # Step 2: register metrics + out_ = self.metric_tracker.register_metrics(loss_dict, when=when) + better_metrics, better_model, stopping_metric = out_ + self.metric_tracker.evaluation_print_better(loss_dict, better_metrics, _print=self.print) + + continue_training = self.controller.push_epoch(self.current_epoch, better_model, stopping_metric, _print=self.print) + + if not continue_training: + self.print("Controller is terminating training.") + self.trainer.should_stop = True + + # Step 3: Logic for changing the batch size without always requiring new dataloaders. + # Step 3a: don't do this when not testing. + if not self.trainer.training: + return + + controller_batch_size = self.controller.batch_size + trainer_batch_size = self.trainer.train_dataloader.batch_size + if controller_batch_size != trainer_batch_size: + # Need to trigger a batch size change. + if self._last_reload_dlene is None: + # save the original value of this variable to the pl module + self._last_reload_dlene = self.trainer.reload_dataloaders_every_n_epochs + + # TODO: Make this run even if there isn't an explicit datamodule? + self.trainer.datamodule.batch_size = controller_batch_size + # Tell PL lightning to reload the dataloaders now. + self.trainer.reload_dataloaders_every_n_epochs = 1 + + elif self._last_reload_dlene is not None: + # Restore the last saved value from the pl module. + self.trainer.reload_dataloaders_every_n_epochs = self._last_reload_dlene + self._last_reload_dlene = None + else: + # Batch sizes match, and there's no variable to restore. + pass + return + + def on_validation_end(self): + self._eval_end(prefix="valid_") + return + + def on_test_end(self): + self._eval_end(prefix="test_", when="test") + return + + +class LightingPrintStagesCallback(pl.Callback): + """ + This callback is for debugging only. + It prints whenever a callback stage is entered in pytorch lightning. 
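+    Add an instance to the trainer's callback list to use it, e.g.
+    ``pl.Trainer(callbacks=[LightingPrintStagesCallback()])``.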
+ """ + + for k in dir(pl.Callback): + if k.startswith("on_"): + + def some_method(self, *args, _k=k, **kwargs): + all_args = kwargs.copy() + all_args.update({i: a for i, a in enumerate(args)}) + int_args = {k: v for k, v in all_args.items() if isinstance(v, int)} + print("Callback stage:", _k, "with integer arguments:", int_args) + + exec(f"{k} = some_method") + del some_method + + +class HippynnDataModule(pl.LightningDataModule): + def __init__(self, database: Database, batch_size): + super().__init__() + self.database = database + self.batch_size = batch_size + + def train_dataloader(self): + return self.database.make_generator("train", "train", self.batch_size) + + def val_dataloader(self): + return self.database.make_generator("valid", "eval", self.batch_size) + + def test_dataloader(self): + return self.database.make_generator("test", "eval", self.batch_size) diff --git a/hippynn/experiment/metric_tracker.py b/hippynn/experiment/metric_tracker.py index f43426e6..af28d7fc 100644 --- a/hippynn/experiment/metric_tracker.py +++ b/hippynn/experiment/metric_tracker.py @@ -85,7 +85,6 @@ def register_metrics(self, metric_info, when): except KeyError: if split_type not in self.best_metric_values: # Haven't seen this split before! - print("ADDING ",split_type) self.best_metric_values[split_type] = {} better_metrics[split_type] = {} better = True # old best was not found! @@ -99,7 +98,7 @@ def register_metrics(self, metric_info, when): else: self.other_metric_values[when] = metric_info - if self.stopping_key: + if self.stopping_key and "valid" in metric_info: better_model = better_metrics.get("valid", {}).get(self.stopping_key, False) stopping_key_metric = metric_info["valid"][self.stopping_key] else: @@ -108,21 +107,21 @@ def register_metrics(self, metric_info, when): return better_metrics, better_model, stopping_key_metric - def evaluation_print(self, evaluation_dict, quiet=None): + def evaluation_print(self, evaluation_dict, quiet=None, _print=print): if quiet is None: quiet = self.quiet if quiet: return - table_evaluation_print(evaluation_dict, self.metric_names, self.name_column_width) + table_evaluation_print(evaluation_dict, self.metric_names, self.name_column_width, _print=_print) - def evaluation_print_better(self, evaluation_dict, better_dict, quiet=None): + def evaluation_print_better(self, evaluation_dict, better_dict, quiet=None, _print=print): if quiet is None: quiet = self.quiet if quiet: return - table_evaluation_print_better(evaluation_dict, better_dict, self.metric_names, self.name_column_width) + table_evaluation_print_better(evaluation_dict, better_dict, self.metric_names, self.name_column_width, _print=print) if self.stopping_key: - print( + _print( "Best {} so far: {:>8.5g}".format( self.stopping_key, self.best_metric_values["valid"][self.stopping_key] ) @@ -134,7 +133,7 @@ def plot_over_time(self): # Driver for printing evaluation table results, with * for better entries. # Decoupled from the estate in case we want to more easily change print formatting. -def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_columns): +def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_columns, _print=print): """ Print metric results as a table, add a '*' character for metrics in better_dict. 
@@ -157,16 +156,16 @@ def table_evaluation_print_better(evaluation_dict, better_dict, metric_names, n_ header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names) rowstring = "{:<" + str(n_columns) + "}: " + " {}{:>10.5g}" * n_types - print(header) - print("-" * len(header)) + _print(header) + _print("-" * len(header)) for n, valsbet in zip(metric_names, transposed_values_better): rowoutput = [k for bv in valsbet for k in bv] - print(rowstring.format(n, *rowoutput)) + _print(rowstring.format(n, *rowoutput)) # Driver for printing evaluation table results. # Decoupled from the estate in case we want to more easily change print formatting. -def table_evaluation_print(evaluation_dict, metric_names, n_columns): +def table_evaluation_print(evaluation_dict, metric_names, n_columns, _print=print): """ Print metric results as a table. @@ -184,8 +183,8 @@ def table_evaluation_print(evaluation_dict, metric_names, n_columns): header = " " * (n_columns + 2) + "".join("{:>14}".format(tn) for tn in type_names) rowstring = "{:<" + str(n_columns) + "}: " + " {:>10.5g}" * n_types - print(header) - print("-" * len(header)) + _print(header) + _print("-" * len(header)) for n, vals in zip(metric_names, transposed_values): - print(rowstring.format(n, *vals)) - print("-" * len(header)) + _print(rowstring.format(n, *vals)) + _print("-" * len(header)) diff --git a/hippynn/experiment/routines.py b/hippynn/experiment/routines.py index f6aee191..84faa5c0 100644 --- a/hippynn/experiment/routines.py +++ b/hippynn/experiment/routines.py @@ -306,9 +306,7 @@ def train_model( print("Finishing up...") print("Training phase ended.") - if store_metrics: - with open("training_metrics.pkl", "wb") as pfile: - pickle.dump(metric_tracker, pfile) + torch.save(metric_tracker, "training_metrics.pt") best_model = metric_tracker.best_model if best_model: @@ -448,6 +446,7 @@ def training_loop( qprint("_" * 50) qprint("Epoch {}:".format(epoch)) tools.print_lr(optimizer) + qprint("Batch Size:", controller.batch_size) qprint(flush=True, end="") diff --git a/hippynn/experiment/serialization.py b/hippynn/experiment/serialization.py index c4d73c1a..326812fa 100644 --- a/hippynn/experiment/serialization.py +++ b/hippynn/experiment/serialization.py @@ -1,5 +1,7 @@ """ -checkpoint and state generation +Checkpoint and state generation. + +As a user, in most cases you will only need the `load` functions here. """ from typing import Tuple, Union @@ -12,7 +14,7 @@ from ..graphs import GraphModule from ..tools import device_fallback from .assembly import TrainingModules -from .controllers import PatienceController +from .controllers import Controller from .device import set_devices from .metric_tracker import MetricTracker @@ -21,13 +23,13 @@ def create_state( model: GraphModule, - controller: PatienceController, + controller: Controller, metric_tracker: MetricTracker, ) -> dict: """Create an experiment state dictionary. :param model: current model - :param controller: patience controller + :param controller: controller :param metric_tracker: current metrics :return: dictionary containing experiment state. 
:rtype: dict @@ -43,7 +45,7 @@ def create_state( def create_structure_file( training_modules: TrainingModules, database: Database, - controller: PatienceController, + controller: Controller, fname=DEFAULT_STRUCTURE_FNAME, ) -> None: """ @@ -51,7 +53,7 @@ def create_structure_file( :param training_modules: contains model, controller, and loss :param database: database for training - :param controller: patience controller + :param controller: controller :param fname: filename to save the checkpoint :return: None diff --git a/hippynn/graphs/gops.py b/hippynn/graphs/gops.py index ee98ddc1..01fc6682 100644 --- a/hippynn/graphs/gops.py +++ b/hippynn/graphs/gops.py @@ -50,7 +50,8 @@ def compute_evaluation_order(all_nodes): evaluation_inputs_list = [] evaluation_outputs_list = [] - unsatisfied_nodes = all_nodes.copy() + # need to sort to get stable results between runs/processes. + unsatisfied_nodes = list(sorted(all_nodes, key=lambda node: node.name)) satisfied_nodes = set() n = -1 while len(unsatisfied_nodes) > 0: diff --git a/hippynn/interfaces/ase_interface/ase_database.py b/hippynn/interfaces/ase_interface/ase_database.py index b3c05057..992e16f4 100644 --- a/hippynn/interfaces/ase_interface/ase_database.py +++ b/hippynn/interfaces/ase_interface/ase_database.py @@ -24,14 +24,14 @@ import os import numpy as np -from ase.io import read +from ase.io import read, iread -from ...tools import np_of_torchdefaultdtype +from ...tools import np_of_torchdefaultdtype, progress_bar from ...databases.database import Database from ...databases.restarter import Restartable from typing import Union from typing import List - +import hippynn.tools class AseDatabase(Database, Restartable): """ @@ -84,11 +84,11 @@ def load_arrays(self, directory, filename, inputs, targets, quiet=False, allow_u var_list = inputs + targets try: if isinstance(filename, str): - db = read(directory + filename, index=":") + db = list(progress_bar(iread(directory+filename,index=":"), desc='configs'))#read(directory + filename, index=":") elif isinstance(filename, (list, np.ndarray)): db = [] - for name in filename: - temp_db = read(directory + name, index=":") + for name in progress_bar(filename, desc='files'): + temp_db = list(progress_bar(iread(directory + name, index=":"), desc='configs')) db += temp_db except FileNotFoundError as fee: raise FileNotFoundError( diff --git a/hippynn/layers/hiplayers.py b/hippynn/layers/hiplayers.py index 2f62c07d..b93aae60 100644 --- a/hippynn/layers/hiplayers.py +++ b/hippynn/layers/hiplayers.py @@ -275,16 +275,26 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) n_atoms_real = in_features.shape[0] sense_vals = self.sensitivity(dist_pairs) + # Sensitivity stacking + sense_vec = sense_vals.unsqueeze(1) * (coord_pairs / dist_pairs.unsqueeze(1)).unsqueeze(2) + sense_vec = sense_vec.reshape(-1, self.n_dist * 3) + sense_stacked = torch.concatenate([sense_vals, sense_vec], dim=1) + + # Message passing, stack sensitivities to coalesce custom kernel call. 
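+        # A single envsum call on the stacked sensitivities replaces the separate scalar and vector envsum calls.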
+ # shape (n_atoms, n_nu + 3*n_nu, n_feat) + env_features_stacked = custom_kernels.envsum(sense_stacked, in_features, pair_first, pair_second) + # shape (n_atoms, 4, n_nu, n_feat) + env_features_stacked = env_features_stacked.reshape(-1, 4, self.n_dist, self.nf_in) + + # separate to tensor components + env_features, env_features_vec = torch.split(env_features_stacked, [1, 3], dim=1) + # Scalar part - env_features = custom_kernels.envsum(sense_vals, in_features, pair_first, pair_second) env_features = torch.reshape(env_features, (n_atoms_real, self.n_dist * self.nf_in)) weights_rs = torch.reshape(self.int_weights.permute(0, 2, 1), (self.n_dist * self.nf_in, self.nf_out)) features_out = torch.mm(env_features, weights_rs) # Vector part - sense_vec = sense_vals.unsqueeze(1) * (coord_pairs / dist_pairs.unsqueeze(1)).unsqueeze(2) - sense_vec = sense_vec.reshape(-1, self.n_dist * 3) - env_features_vec = custom_kernels.envsum(sense_vec, in_features, pair_first, pair_second) env_features_vec = env_features_vec.reshape(n_atoms_real * 3, self.n_dist * self.nf_in) features_out_vec = torch.mm(env_features_vec, weights_rs) features_out_vec = features_out_vec.reshape(n_atoms_real, 3, self.nf_out) @@ -315,19 +325,41 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) n_atoms_real = in_features.shape[0] sense_vals = self.sensitivity(dist_pairs) - # Scalar part - env_features = custom_kernels.envsum(sense_vals, in_features, pair_first, pair_second) + #### + # Sensitivity calculations + # scalar: sense_vals + # vector: sense_vec + # quadrupole: sense_quad + rhats = coord_pairs / dist_pairs.unsqueeze(1) + sense_vec = sense_vals.unsqueeze(1) * rhats.unsqueeze(2) + sense_vec = sense_vec.reshape(-1, self.n_dist * 3) + rhatsquad = rhats.unsqueeze(1) * rhats.unsqueeze(2) + rhatsquad = (rhatsquad + rhatsquad.transpose(1, 2)) / 2 + tr = torch.diagonal(rhatsquad, dim1=1, dim2=2).sum(dim=1) / 3.0 # Add divide by 3 early to save flops + tr = tr.unsqueeze(1).unsqueeze(2) * torch.eye(3, dtype=tr.dtype, device=tr.device).unsqueeze(0) + rhatsquad = rhatsquad - tr + rhatsqflat = rhatsquad.reshape(-1, 9)[:, self.upper_ind] # Upper-diagonal part + sense_quad = sense_vals.unsqueeze(1) * rhatsqflat.unsqueeze(2) + sense_quad = sense_quad.reshape(-1, self.n_dist * 5) + sense_stacked = torch.concatenate([sense_vals, sense_vec, sense_quad], dim=1) + + # Message passing, stack sensitivities to coalesce custom kernel call. + # shape (n_atoms, n_nu + 3*n_nu + 5*n_nu, n_feat) + env_features_stacked = custom_kernels.envsum(sense_stacked, in_features, pair_first, pair_second) + # shape (n_atoms, 9, n_nu, n_feat) + env_features_stacked = env_features_stacked.reshape(-1, 9, self.n_dist, self.nf_in) + + # separate to tensor components + env_features, env_features_vec, env_features_quad = torch.split(env_features_stacked, [1, 3, 5], dim=1) + + # Scalar stuff. 
env_features = torch.reshape(env_features, (n_atoms_real, self.n_dist * self.nf_in)) weights_rs = torch.reshape(self.int_weights.permute(0, 2, 1), (self.n_dist * self.nf_in, self.nf_out)) features_out = torch.mm(env_features, weights_rs) # Vector part # Sensitivity - rhats = coord_pairs / dist_pairs.unsqueeze(1) - sense_vec = sense_vals.unsqueeze(1) * rhats.unsqueeze(2) - sense_vec = sense_vec.reshape(-1, self.n_dist * 3) # Weights - env_features_vec = custom_kernels.envsum(sense_vec, in_features, pair_first, pair_second) env_features_vec = env_features_vec.reshape(n_atoms_real * 3, self.n_dist * self.nf_in) features_out_vec = torch.mm(env_features_vec, weights_rs) # Norm and scale @@ -338,16 +370,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) # Quadrupole part # Sensitivity - rhatsquad = rhats.unsqueeze(1) * rhats.unsqueeze(2) - rhatsquad = (rhatsquad + rhatsquad.transpose(1, 2)) / 2 - tr = torch.diagonal(rhatsquad, dim1=1, dim2=2).sum(dim=1) / 3.0 # Add divide by 3 early to save flops - tr = tr.unsqueeze(1).unsqueeze(2) * torch.eye(3, dtype=tr.dtype, device=tr.device).unsqueeze(0) - rhatsquad = rhatsquad - tr - rhatsqflat = rhatsquad.reshape(-1, 9)[:, self.upper_ind] # Upper-diagonal part - sense_quad = sense_vals.unsqueeze(1) * rhatsqflat.unsqueeze(2) - sense_quad = sense_quad.reshape(-1, self.n_dist * 5) # Weights - env_features_quad = custom_kernels.envsum(sense_quad, in_features, pair_first, pair_second) env_features_quad = env_features_quad.reshape(n_atoms_real * 5, self.n_dist * self.nf_in) features_out_quad = torch.mm(env_features_quad, weights_rs) ##sum v b features_out_quad = features_out_quad.reshape(n_atoms_real, 5, self.nf_out) @@ -359,6 +382,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) # Scales features_out_quad = features_out_quad * self.quadscales.unsqueeze(0) + # Combine features_out_selfpart = self.selfint(in_features) features_out_total = features_out + features_out_vec + features_out_quad + features_out_selfpart diff --git a/hippynn/pretraining.py b/hippynn/pretraining.py index e039905a..726186fc 100644 --- a/hippynn/pretraining.py +++ b/hippynn/pretraining.py @@ -70,7 +70,7 @@ def hierarchical_energy_initialization( if not eo_layer.weight.data.shape[-1] == eovals.shape[-1]: raise ValueError("The shape of the computed E0 values does not match the shape expected by the model.") - eo_layer.weight.data = eovals.reshape(1,-1) + eo_layer.weight.data = eovals.reshape(1, -1) print("Computed E0 energies:", eovals) eo_layer.weight.data = eovals.expand_as(eo_layer.weight.data) eo_layer.weight.requires_grad_(trainable_after) diff --git a/hippynn/tools.py b/hippynn/tools.py index d2c78133..df1507cb 100644 --- a/hippynn/tools.py +++ b/hippynn/tools.py @@ -133,9 +133,9 @@ def arrdict_len(array_dictionary): return len(next(iter(array_dictionary.values()))) -def print_lr(optimizer): +def print_lr(optimizer, print_=print): for i, param_group in enumerate(optimizer.param_groups): - print("Learning rate:{:>10.5g}".format(param_group["lr"])) + print_("Learning rate:{:>10.5g}".format(param_group["lr"])) def isiterable(obj): @@ -217,3 +217,18 @@ def is_equal_state_dict(d1, d2, raise_where=False): return True +def recursive_param_count(state_dict, n=0): + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + n += v.numel() + elif isinstance(v, dict): + n += recursive_param_count(v) + elif isinstance(v, (list, tuple)): + n += recursive_param_count({i: x for i, x in enumerate(v)}) + elif isinstance(v, 
(float, int)): + n += 1 + elif v is None: + pass + else: + raise TypeError(f'Unknown type {type(v)=}, value={v}') + return n diff --git a/setup.py b/setup.py index 3f0100a4..95d1333e 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ "tqdm", "graphviz", "h5py", + "lightning", ] setuptools.setup( diff --git a/tests/lightning_QM7_test.py b/tests/lightning_QM7_test.py new file mode 100644 index 00000000..ac1d0b0f --- /dev/null +++ b/tests/lightning_QM7_test.py @@ -0,0 +1,219 @@ +""" + +This is a test script based on /examples/QM7_example.py which uses pytorch lightning to train. + +""" + +PERFORM_PLOTTING = True # Make sure you have matplotlib if you want to set this to TRUE + +#### Setup pytorch things +import torch + +torch.set_default_dtype(torch.float32) + +if torch.cuda.is_available(): + torch.cuda.set_device(0) # Don't try this if you want CPU training! + +import hippynn + + +def main(): + hippynn.settings.WARN_LOW_DISTANCES = False + + # Note: these settings may need to be adjusted depending on the platform where + # this code is run. + n_devices = 2 + num_workers = 0 + multiprocessing_context = "fork" + + # Hyperparameters for the network + netname = "TEST_LIGHTNING_MODEL" + network_params = { + "possible_species": [0, 1, 6, 7, 8, 16], # Z values of the elements + "n_features": 20, # Number of neurons at each layer + "n_sensitivities": 20, # Number of sensitivity functions in an interaction layer + "dist_soft_min": 1.6, # + "dist_soft_max": 10.0, + "dist_hard_max": 12.5, + "n_interaction_layers": 2, # Number of interaction blocks + "n_atom_layers": 3, # Number of atom layers in an interaction block + } + + # Define a model + + from hippynn.graphs import inputs, networks, targets, physics + + # model inputs + species = inputs.SpeciesNode(db_name="Z") + positions = inputs.PositionsNode(db_name="R") + + # Model computations + network = networks.HipnnVec("HIPNN", (species, positions), module_kwargs=network_params) + henergy = targets.HEnergyNode("HEnergy", network) + molecule_energy = henergy.mol_energy + molecule_energy.db_name = "T" + hierarchicality = henergy.hierarchicality + + # define loss quantities + from hippynn.graphs import loss + + rmse_energy = loss.MSELoss.of_node(molecule_energy) ** (1 / 2) + mae_energy = loss.MAELoss.of_node(molecule_energy) + rsq_energy = loss.Rsq.of_node(molecule_energy) + + ### More advanced usage of loss graph + + pred_per_atom = physics.PerAtom("PeratomPredicted", (molecule_energy, species)).pred + true_per_atom = physics.PerAtom("PeratomTrue", (molecule_energy.true, species.true)) + mae_per_atom = loss.MAELoss(pred_per_atom, true_per_atom) + + ### End more advanced usage of loss graph + + loss_error = rmse_energy + mae_energy + + rbar = loss.Mean.of_node(hierarchicality) + l2_reg = loss.l2reg(network) + loss_regularization = 1e-6 * l2_reg + rbar # L2 regularization and hierarchicality regularization + + train_loss = loss_error + loss_regularization + + # Validation losses are what we check on the data between epochs -- we can only train to + # a single loss, but we can check other metrics too to better understand how the model is training. + # There will also be plots of these things over time when training completes. 
+ validation_losses = { + "T-RMSE": rmse_energy, + "T-MAE": mae_energy, + "T-RSQ": rsq_energy, + "TperAtom MAE": mae_per_atom, + "T-Hier": rbar, + "L2Reg": l2_reg, + "Loss-Err": loss_error, + "Loss-Reg": loss_regularization, + "Loss": train_loss, + } + early_stopping_key = "Loss-Err" + + if PERFORM_PLOTTING: + + from hippynn import plotting + + plot_maker = plotting.PlotMaker( + # Simple plots which compare the network to the database + plotting.Hist2D.compare(molecule_energy, saved=True), + # Slightly more advanced control of plotting! + plotting.Hist2D( + true_per_atom, + pred_per_atom, + xlabel="True Energy/Atom", + ylabel="Predicted Energy/Atom", + saved="PerAtomEn.pdf", + ), + plotting.HierarchicalityPlot(hierarchicality.pred, molecule_energy.pred - molecule_energy.true, saved="HierPlot.pdf"), + plot_every=10, # How often to make plots -- here, epoch 0, 10, 20... + ) + else: + plot_maker = None + + from hippynn.experiment import assemble_for_training + + # This piece of code glues the stuff together as a pytorch model, + # dropping things that are irrelevant for the losses defined. + training_modules, db_info = assemble_for_training(train_loss, validation_losses, plot_maker=plot_maker) + training_modules[0].print_structure() + + if num_workers > 0: + dataloader_kwargs = dict(multiprocessing_context=multiprocessing_context, persistent_workers=True) + else: + dataloader_kwargs = None + database_params = { + "name": "qm7", # Prefix for arrays in folder + "directory": "../../datasets/qm7_processed", + "quiet": False, + "test_size": 0.1, + "valid_size": 0.1, + "seed": 2001, + # How many samples from the training set to use during evaluation + **db_info, # Adds the inputs and targets names from the model as things to load + "dataloader_kwargs": dataloader_kwargs, + "num_workers": num_workers, + } + + from hippynn.databases import DirectoryDatabase + + database = DirectoryDatabase(**database_params) + + # Now that we have a database and a model, we can + # Fit the non-interacting energies by examining the database. 
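+    # This tends to stabilize training a lot.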
+ + from hippynn.pretraining import hierarchical_energy_initialization + + hierarchical_energy_initialization(henergy, database, trainable_after=False) + + from hippynn.experiment.controllers import PatienceController + from torch.optim.lr_scheduler import ReduceLROnPlateau + + optimizer = torch.optim.Adam(training_modules.model.parameters(), lr=1e-3) + + scheduler = ReduceLROnPlateau( + optimizer=optimizer, + factor=0.5, + patience=1, + ) + + controller = PatienceController( + optimizer=optimizer, + scheduler=scheduler, + batch_size=16, # start batch size + eval_batch_size=16, + max_epochs=3, + termination_patience=10, + fraction_train_eval=0.1, + stopping_key=early_stopping_key, + ) + + experiment_params = hippynn.experiment.SetupParams( + controller=controller, + ) + + from hippynn.experiment import HippynnLightningModule + + lightmod, datamodule = HippynnLightningModule.from_experiment_setup(training_modules, database, experiment_params) + import pytorch_lightning as pl + from pytorch_lightning.loggers import CSVLogger + + logger = CSVLogger(save_dir=".", name=netname, flush_logs_every_n_steps=100) + from pytorch_lightning.callbacks import ModelCheckpoint + + checkpointer = ModelCheckpoint( + monitor=f"valid_{early_stopping_key}", + save_last=True, + save_top_k=5, + every_n_epochs=1, + every_n_train_steps=None, + ) + + from hippynn.experiment.lightning_trainer import LightingPrintStagesCallback + + cb = LightingPrintStagesCallback() # include this callback if you aren't sure what stage of lightning is broken. + + # The default accelerator, 'auto' detects MPS on mac. hippynn doesn't work on MPS (yet). + # So we set cpu here. + trainer = pl.Trainer( + accelerator="cpu", + logger=logger, + num_nodes=1, + devices=n_devices, + callbacks=[checkpointer], + log_every_n_steps=1, + max_epochs=-1, # This is set this way because the hippynn controller should terminate training. + ) + + trainer.fit( + model=lightmod, + datamodule=datamodule, + ) + trainer.test(datamodule=datamodule, ckpt_path="best") + + +if __name__ == "__main__": + main() From 10a9055e29635ba412739ba810e27cfe0bb1c66e Mon Sep 17 00:00:00 2001 From: Emily Shinkle Date: Mon, 9 Sep 2024 11:05:42 -0600 Subject: [PATCH 3/6] Add coarse-graining example (#98) * first draft * track unwrapped and wrapped positions in MD code when cell is present, fix typo * remove unused imports, update md length * add link to data on Zenodo, change dataset filename --- examples/coarse-graining/README.rst | 7 + examples/coarse-graining/cg_md.py | 115 +++++++++++++ examples/coarse-graining/cg_training.py | 162 ++++++++++++++++++ .../coarse-graining/repulsive_potential.py | 58 +++++++ hippynn/experiment/serialization.py | 2 +- hippynn/molecular_dynamics/md.py | 14 +- 6 files changed, 351 insertions(+), 7 deletions(-) create mode 100644 examples/coarse-graining/README.rst create mode 100644 examples/coarse-graining/cg_md.py create mode 100644 examples/coarse-graining/cg_training.py create mode 100644 examples/coarse-graining/repulsive_potential.py diff --git a/examples/coarse-graining/README.rst b/examples/coarse-graining/README.rst new file mode 100644 index 00000000..9175cb3e --- /dev/null +++ b/examples/coarse-graining/README.rst @@ -0,0 +1,7 @@ +The files in this directory allow one to train and run MD with a coarse-grained HIPNN model. Details of this model can be found in the paper "Thermodynamic Transferability in Coarse-Grained Force Fields using Graph Neural Networks" by Shinkle et. al. available at . 
+ +Before executing these files, one must download the training data from . The file should be placed at `datasets/cg_methanol_trajectory.npz` where `datasets/` is at the same level as the hippynn repository. + +1. Run `cg_training.py` to generate a model. This model will be saved in `hippynn/examples/coarse-graining/model`. +2. Run `cg_md.py` to run MD using the model trained in step 1. The resulting trajectory will be saved in `hippynn/examples/coarse-graining/md_results`. + diff --git a/examples/coarse-graining/cg_md.py b/examples/coarse-graining/cg_md.py new file mode 100644 index 00000000..4ffc7885 --- /dev/null +++ b/examples/coarse-graining/cg_md.py @@ -0,0 +1,115 @@ +import os + +import numpy as np +import torch + +from ase import units + +from hippynn.experiment.serialization import load_checkpoint_from_cwd +from hippynn.graphs.predictor import Predictor +from hippynn.molecular_dynamics.md import ( + Variable, + NullUpdater, + LangevinDynamics, + MolecularDynamics, +) +from hippynn.tools import active_directory + +default_dtype=torch.float +torch.set_default_dtype(default_dtype) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Load initial conditions +training_data_file = os.path.join(os.pardir,os.pardir,os.pardir,"datasets","cg_methanol_trajectory.npz") + +with np.load(training_data_file) as data: + cell = torch.as_tensor(data["cells"][-1], dtype=default_dtype, device=device)[None,...] + masses = torch.as_tensor(data["masses"][-1], dtype=default_dtype, device=device)[None,...] + positions = torch.as_tensor(data["positions"][-1], dtype=default_dtype, device=device)[None,...] + velocities = torch.as_tensor(data["velocities"][-1], dtype=default_dtype, device=device)[None,...] + species = torch.as_tensor(data["species"][-1], dtype=torch.int, device=device)[None,...] 
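+# The trailing [None, ...] above adds a leading batch axis of size one:
+# the molecular dynamics engine can batch over systems, so even this
+# single system is given a batch dimension here.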
+ +positions_variable = Variable( + name="positions", + data={ + "position": positions, + "velocity": velocities, + "mass": masses, + "acceleration": torch.zeros_like(velocities), + "cell": cell, + }, + model_input_map={"positions": "position"}, + device=device, +) + +position_updater = LangevinDynamics( + force_db_name="forces", + temperature=700, + frix=6, + units_force=units.kcal / units.mol / units.Ang, + units_acc=units.Ang / ((1000 * units.fs)**2), + seed=1993, +) +positions_variable.updater = position_updater + +cell_variable = Variable( + name="cell", + data={"cell": cell}, + model_input_map={"cells": "cell"}, + device=device, + updater=NullUpdater(), +) + +species_variable = Variable( + name="species", + data={"species": species}, + model_input_map={"species": "species"}, + device=device, + updater=NullUpdater(), +) + +# Load model +with active_directory("model"): + check = load_checkpoint_from_cwd(model_device=device, restart_db=False) + +repulse = check["training_modules"].model.node_from_name("repulse") +energy = check["training_modules"].model.node_from_name("sys_energy") + +model = Predictor.from_graph( + check["training_modules"].model, + additional_outputs=[ + repulse.mol_energies, + energy, + ], +) + +model = Predictor.from_graph(check["training_modules"].model) + +model.to(default_dtype) +model.to(device) + +pairs = model.graph.node_from_name("pairs") +pairs.skin = 3 # see hippynn.graphs.nodes.pairs.KDTreePairsMemory documentation + +# Run MD +with active_directory("md_results"): + emdee = MolecularDynamics( + variables=[positions_variable, species_variable, cell_variable], + model=model, + ) + + emdee.run(dt=0.001, n_steps=20000) + emdee.run(dt=0.001, n_steps=50000, record_every=50) + + data = emdee.get_data() + np.savez("hippynn_cg_trajectory.npz", + positions = data["positions_position"], + velocities = data["positions_velocity"], + masses = data["positions_mass"], + accelerations = data["positions_acceleration"], + cells = data["positions_cell"], + unwrapped_positions = data["positions_unwrapped_position"], + forces = data["positions_force"], + species = data["species_species"], + ) \ No newline at end of file diff --git a/examples/coarse-graining/cg_training.py b/examples/coarse-graining/cg_training.py new file mode 100644 index 00000000..699e93fb --- /dev/null +++ b/examples/coarse-graining/cg_training.py @@ -0,0 +1,162 @@ +import os + +import numpy as np +import torch + +from hippynn.databases import NPZDatabase +from hippynn.experiment import SetupParams, setup_and_train +from hippynn.experiment.assembly import assemble_for_training +from hippynn.experiment.controllers import RaiseBatchSizeOnPlateau, PatienceController +from hippynn.graphs import IdxType +from hippynn.graphs.nodes import loss +from hippynn.graphs.nodes.base.algebra import AddNode +from hippynn.graphs.nodes.indexers import acquire_encoding_padding +from hippynn.graphs.nodes.inputs import SpeciesNode, PositionsNode, CellNode +from hippynn.graphs.nodes.networks import HipnnQuad +from hippynn.graphs.nodes.pairs import KDTreePairsMemory +from hippynn.graphs.nodes.physics import MultiGradientNode +from hippynn.graphs.nodes.targets import HEnergyNode +from hippynn.plotting import PlotMaker, Hist2D, SensitivityPlot +from hippynn.tools import active_directory + +from repulsive_potential import RepulsivePotentialNode + +training_data_file = os.path.join(os.pardir,os.pardir,os.pardir,"datasets","cg_methanol_trajectory.npz") + +with np.load(training_data_file) as data: + idx = np.where(data["rdf_values"] > 
0.01)[0][0] + repulsive_potential_taper_point = data["rdf_bins"][idx] + repulsive_potential_strength = np.abs(data["forces"]).mean() + +## Initialize needed nodes for network +# Network input nodes +species = SpeciesNode(name="species", db_name="species") +positions = PositionsNode(name="positions", db_name="positions") +cells = CellNode(name="cells", db_name="cells") + +# Network hyperparameters +network_params = { + "possible_species": [0,1], + "n_features": 128, + "n_sensitivities": 20, + "dist_soft_min": 2.0, + "dist_soft_max": 13.0, + "dist_hard_max": 15.0, + "n_interaction_layers": 1, + "n_atom_layers": 3, + "sensitivity_type": "inverse", + "resnet": True, +} + +# Species encoder +enc, pdx = acquire_encoding_padding([species], species_set=[0,1]) + +# Pair finder +pair_finder = KDTreePairsMemory( + "pairs", + (positions, enc, pdx, cells), + dist_hard_max=network_params["dist_hard_max"], + skin=0, +) + +# HIP-NN-TS node with l=2 +network = HipnnQuad( + "HIPNN", (pdx, pair_finder), module_kwargs=network_params, periodic=True +) + +# Network energy prediction +henergy = HEnergyNode("HEnergy", parents=(network,)) + +# Repulsive potential +repulse = RepulsivePotentialNode( + "repulse", + (pair_finder, pdx), + taper_point=repulsive_potential_taper_point, + strength=repulsive_potential_strength, + dr=0.15, + perc=0.05, +) + +# Combined energy prediction +energy = AddNode(henergy.main_output, repulse.mol_energies) +energy.name = "energies" +energy._index_state = IdxType.Molecules + +sys_energy = energy.main_output +sys_energy.name = "sys_energy" + +# Force node +grad = MultiGradientNode("forces", energy, (positions,), signs=-1) +force = grad.children[0] +force.db_name = "forces" + +## Define losses +force_rsq = loss.Rsq.of_node(force) +force_rmse = loss.MSELoss.of_node(force) ** (1 / 2) +force_mae = loss.MAELoss.of_node(force) +total_loss = force_rmse + force_mae + +validation_losses = { + "ForceRMSE": force_rmse, + "ForceMAE": force_mae, + "ForceRsq": force_rsq, + "TotalLoss": total_loss, +} + +plotters = [ + Hist2D.compare(force, saved="forces", shown=False), + SensitivityPlot( + network.torch_module.sensitivity_layers[0], saved="sensitivity", shown=False + ), +] + +plot_maker = PlotMaker( + *plotters, + plot_every=10, +) + +## Build network +training_modules, db_info = assemble_for_training( + total_loss, validation_losses, plot_maker=plot_maker +) + +## Load training data +database = NPZDatabase( + training_data_file, + seed=0, + **db_info, + valid_size=0.1, + test_size=0.1, +) + +## Set up optimizer +optimizer = torch.optim.Adam(training_modules.model.parameters(), lr=1e-3) + +scheduler = RaiseBatchSizeOnPlateau( + optimizer=optimizer, + max_batch_size=64, + patience=10, + factor=0.5, +) + +controller = PatienceController( + optimizer=optimizer, + scheduler=scheduler, + batch_size=1, + fraction_train_eval=0.2, + eval_batch_size=1, + max_epochs=200, + termination_patience=20, + stopping_key="TotalLoss", +) + +experiment_params = SetupParams(controller=controller) + +## Train! 
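+# Training runs inside a "model" subdirectory (see active_directory below), so the
+# checkpoint written here is the one that cg_md.py later loads with
+# load_checkpoint_from_cwd from that same directory.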
+with active_directory("model"): + metric_tracker = setup_and_train( + training_modules=training_modules, + database=database, + setup_params=experiment_params, + ) + \ No newline at end of file diff --git a/examples/coarse-graining/repulsive_potential.py b/examples/coarse-graining/repulsive_potential.py new file mode 100644 index 00000000..6a72d323 --- /dev/null +++ b/examples/coarse-graining/repulsive_potential.py @@ -0,0 +1,58 @@ +import math +import torch + +from hippynn.graphs import IdxType +from hippynn.graphs.nodes.base import ExpandParents +from hippynn.graphs.nodes.base.definition_helpers import AutoKw +from hippynn.graphs.nodes.base.multi import MultiNode +from hippynn.graphs.nodes.tags import PairIndexer, AtomIndexer +from hippynn.layers import pairs + +## Define repulsive potential node for hippynn +class RepulsivePotential(torch.nn.Module): + def __init__(self, taper_point, strength, dr, perc): + ''' + Let F(r) be the force between two particles of distance r generated + by this potential. Then + F(taper_point) = perc * strength + F(taper_point - dr) = strength + + Eg. If taper_point=3, strength=1, dr=0.5, and perc=0.01, then + F(3) = 0.01 + F(2.5) = 1 + ''' + super().__init__() + self.t = taper_point + self.s = strength + self.d = dr + self.p = perc + + self.a = (1/self.d)*math.log(1/self.p) + self.g = -1 * self.s * self.p * math.exp(self.a * self.t) / self.a + + self.summer = pairs.MolPairSummer() + + def forward(self, pair_dist, pair_first, mol_index, n_molecules): + atom_energies = -1 * self.g * torch.exp(-1 * self.a * pair_dist) + mol_energies = self.summer(atom_energies, mol_index, n_molecules, pair_first) + return mol_energies, atom_energies, + +class RepulsivePotentialNode(ExpandParents, AutoKw, MultiNode): + _input_names = "pair_dist", "pair_first", "mol_index", "n_molecules" + _output_names = "mol_energies", "atom_energies", + _auto_module_class = RepulsivePotential + _output_index_states = IdxType.Molecules, IdxType.Pair, + + @_parent_expander.match(PairIndexer, AtomIndexer) + def expansion(self, pairfinder, pidxer, **kwargs): + return pairfinder.pair_dist, pairfinder.pair_first, pidxer.mol_index, pidxer.n_molecules + + def __init__(self, name, parents, taper_point, strength, dr, perc, module="auto"): + self.module_kwargs = { + "taper_point": taper_point, + "strength": strength, + "dr": dr, + "perc": perc, + } + parents = self.expand_parents(parents, module="auto") + super().__init__(name, parents, module=module) \ No newline at end of file diff --git a/hippynn/experiment/serialization.py b/hippynn/experiment/serialization.py index 326812fa..57f9574c 100644 --- a/hippynn/experiment/serialization.py +++ b/hippynn/experiment/serialization.py @@ -140,7 +140,7 @@ def load_checkpoint( :param structure_fname: name of the structure file :param state_fname: name of the state file - :param restart_db: restore database or not, defaults to True + :param restart_db: restore database or not, defaults to False :param map_location: device mapping argument for ``torch.load``, defaults to None :param model_device: automatically handle device mapping. 
Defaults to None, defaults to None :return: experiment structure diff --git a/hippynn/molecular_dynamics/md.py b/hippynn/molecular_dynamics/md.py index b16cc552..8752990c 100644 --- a/hippynn/molecular_dynamics/md.py +++ b/hippynn/molecular_dynamics/md.py @@ -1,5 +1,6 @@ from __future__ import annotations from functools import singledispatchmethod +from copy import copy import numpy as np import torch @@ -9,7 +10,6 @@ from ..graphs import Predictor from ..layers.pairs.periodic import wrap_systems_torch - class Variable: """ Tracks the state of a quantity (eg. position, cell, species, @@ -307,11 +307,12 @@ def pre_step(self, dt): self.variable.data["position"] = self.variable.data["position"] + self.variable.data["velocity"] * dt - try: - _, self.variable.data["position"], *_ = wrap_systems_torch(coords=self.variable.data["position"], cell=self.variable.data["cell"], cutoff=0) # cutoff only used for discarded outputs; can be set arbitrarily - except KeyError: - pass - + if "cell" in self.variable.data.keys(): + _, self.variable.data["position"], *_ = wrap_systems_torch(coords=self.variable.data["position"], cell=self.variable.data["cell"], cutoff=0) # cutoff only impacts unused outputs; can be set arbitrarily + try: + self.variable.data["unwrapped_position"] = self.variable.data["unwrapped_position"] + self.variable.data["velocity"] * dt + except KeyError: + self.variable.data["unwrapped_position"] = copy(self.variable.data["position"]) def post_step(self, dt, model_outputs): """Updates to variables performed during each step of MD simulation after HIPNN model evaluation @@ -406,6 +407,7 @@ def model(self, model): + f" Entries in the 'model_input_map' should have the form 'hipnn-db_name: variable-data-key' where 'hipnn-db_name'" + f" refers to the db_name of an input for the hippynn Predictor model," + f" and 'variable-data-key' corresponds to a key in the 'data' dictionary of one of the Variables." + + f" Currently assigned db_names are: {variable_data_db_names}." 
) self._model = model From be79f3a85ef2ebc7425d5d9a15eb3902ed469b50 Mon Sep 17 00:00:00 2001 From: Emily Shinkle Date: Mon, 9 Sep 2024 11:13:26 -0600 Subject: [PATCH 4/6] change setup to include pyproject.toml file (#97) * change setup to include pyproject.toml file * update github actions to use python -m build * add build dependency intallation * get versioneer with pyproject.toml support --- .github/workflows/build.yml | 4 ++-- .github/workflows/deploy.yml | 4 ++-- pyproject.toml | 40 +++++++++++++++++++++++++++++++++++ setup.py | 41 +----------------------------------- 4 files changed, 45 insertions(+), 44 deletions(-) create mode 100644 pyproject.toml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e8abe0c7..a71debfe 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,7 +18,7 @@ jobs: - name: Install dependencies run: >- - python -m pip install --user --upgrade setuptools wheel + python -m pip install --user --upgrade setuptools wheel build - name: Build run: >- - python setup.py sdist bdist_wheel + python -m build diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index e4044c6a..de34b743 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -27,10 +27,10 @@ jobs: - name: Install dependencies run: >- - python -m pip install --user --upgrade setuptools wheel + python -m pip install --user --upgrade setuptools wheel build - name: Build run: >- - python setup.py sdist bdist_wheel + python -m build - name: Publish distribution 📦 to PyPI if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..c95a1da9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[build-system] +requires=["setuptools>=64", "versioneer[toml]"] +build-backend="setuptools.build_meta" + +[project] +name="hippynn" +dynamic=["version"] +authors=[{name="Nicholas Lubbers et al",email="hippynn@lanl.gov"}] +description="The hippynn python package - a modular library for atomistic machine learning with pytorch" +requires-python=">=3.9" +license={text="BSD 3-Clause License"} +classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Science/Research", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Physics", + "Topic :: Scientific/Engineering :: Chemistry", + "Topic :: Software Development :: Libraries", +] +readme="README.rst" +dependencies=[ + "numpy", + "torch", +] + +[project.optional-dependencies] +docs=[ + "sphinx", + "sphinx_rtd_theme", + "ase", +] +full=[ + "ase", + "numba", + "matplotlib", + "tqdm", + "graphviz", + "h5py", + "lightning", +] \ No newline at end of file diff --git a/setup.py b/setup.py index 95d1333e..9592002d 100644 --- a/setup.py +++ b/setup.py @@ -1,47 +1,8 @@ import setuptools import versioneer -with open("README.rst", "r") as fh: - long_description = fh.read() - -doc_requirements = [ - "sphinx", - "sphinx_rtd_theme", - "ase", -] - -full_requirements = [ - "ase", - "numba", - "matplotlib", - "tqdm", - "graphviz", - "h5py", - "lightning", -] - setuptools.setup( - name="hippynn", version=versioneer.get_version(), - author="Nicholas Lubbers et al", - author_email="hippynn@lanl.gov", - python_requires=">=3.9", - install_requires=[ - "numpy", - "torch", - ], - extras_require={"docs": doc_requirements, "full": full_requirements}, - license="BSD 3-Clause License", - 
classifiers=[ - "Development Status :: 3 - Alpha", - "Intended Audience :: Science/Research", - "Programming Language :: Python :: 3", - "Topic :: Scientific/Engineering :: Physics", - "Topic :: Scientific/Engineering :: Chemistry", - "Topic :: Software Development :: Libraries", - ], - description="The hippynn python package - a modular library for atomistic machine learning with pytorch", - long_description=long_description, packages=setuptools.find_packages(), cmdclass=versioneer.get_cmdclass(), -) +) \ No newline at end of file From 110b8316196654857988d2eab12b255ada1ca2f0 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers <56895592+lubbersnick@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:48:49 -0600 Subject: [PATCH 5/6] Many documentation updates, small tweaks to database interface. (#100) --- AUTHORS.txt | 4 +- CHANGELOG.rst | 32 ++- COPYRIGHT.txt | 10 + LICENSE.txt | 10 - README.rst | 1 + docs/source/conf.py | 4 +- docs/source/examples/controller.rst | 1 - docs/source/examples/index.rst | 6 +- docs/source/examples/lightning.rst | 20 ++ docs/source/index.rst | 40 +++- docs/source/installation.rst | 4 +- docs/source/user_guide/ckernels.rst | 2 +- docs/source/user_guide/concepts.rst | 5 +- docs/source/user_guide/databases.rst | 43 +++- docs/source/user_guide/features.rst | 15 +- docs/source/user_guide/settings.rst | 5 + examples/ani1x_training.py | 2 +- hippynn/__init__.py | 17 +- hippynn/databases/__init__.py | 9 +- hippynn/databases/database.py | 261 ++++++++++++++++-------- hippynn/databases/h5_pyanitools.py | 76 ++++--- hippynn/databases/ondisk.py | 2 +- hippynn/experiment/__init__.py | 5 +- hippynn/experiment/lightning_trainer.py | 101 ++++++++- hippynn/experiment/routines.py | 4 +- hippynn/layers/pairs/filters.py | 5 +- hippynn/molecular_dynamics/__init__.py | 5 +- hippynn/molecular_dynamics/md.py | 1 + hippynn/tools.py | 19 +- 29 files changed, 530 insertions(+), 179 deletions(-) create mode 100644 COPYRIGHT.txt create mode 100644 docs/source/examples/lightning.rst diff --git a/AUTHORS.txt b/AUTHORS.txt index 4e0d97f7..8b6c4bac 100644 --- a/AUTHORS.txt +++ b/AUTHORS.txt @@ -19,7 +19,7 @@ Emily Shinkle (LANL) Michael G. Taylor (LANL) Jan Janssen (LANL) Cagri Kaymak (LANL) -Shuhao Zhang (CMU, LANL) +Shuhao Zhang (CMU, LANL) - Batched Optimization routines Also thanks to testing and feedback from: @@ -36,3 +36,5 @@ David Rosenberger Michael Tynes Drew Rohskopf Neil Mehta +Alice E A Allen + diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1618d1bd..6842e0d5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,23 +3,47 @@ Breaking changes: ----------------- +- set_e0_values has been renamed hierarchical_energy_initialization. The old name is + still provided but deprecated, and will be removed. + New Features: ------------- -- Added a new custom cuda kernel implementation using triton. These are highly performant and now the default implementation. -- Exporting a database to NPZ or H5 format after preprocessing is now just a function call away. -- SNAPjson format can now support an optional number of comment lines. -- Added Batch optimizer features in order to optimize geometries in parallel on the GPU. Algorithms include FIRE and BFGS. +- Added a new custom cuda kernel implementation using triton. + These are highly performant and now the default implementation. +- Exporting any database to NPZ or H5 format after preprocessing can be done with a method call. +- Database states can be cached to disk to simplify the restarting of training. 
+- Added batch geometry optimizer features in order to optimize geometries + in parallel on the GPU. Algorithms include FIRE, Newton-Raphson, and BFGS. +- Added experiment pytorch lightning trainer to provide for simple parallelized training. +- Added a molecular dynamics engine which includes the ability to batch over systems. +- Added examples pertaining to coarse graining. +- Added pair finders based on scipy KDTree for training to large systems. +- Added tool to drastically simplify creating ensemble models. The ensemblized graphs + are compatible with molecular dynamics codes such ASE and LAMMPS. +- Added the ability to weight different systems/atoms/bonds in a loss function. + Improvements: ------------- - Eliminated dependency on pyanitools for loading ANI-style H5 datasets. +- SNAPjson format can now support an optional number of comment lines. +- Added unit conversion options to the LAMMPS interface. +- Improved performance of bond order regression. +- It is now possible to limit the memory usage of the MLIAP interface in LAMMPS + using a library setting. +- Provide tunable regularization of HIP-NN-TS with an epsilon parameter, and + set the default to use a better value for epsilon. + Bug Fixes: ---------- - Fixed bug where custom kernels were not launching properly on non-default GPUs +- Fixed error when LAMMPS interface is in kokkos mode and the kokkos device was set to CPU. +- MLIAPInterface objects +- Fixed bug with RDF computer automatic initialization. 0.0.3 ======= diff --git a/COPYRIGHT.txt b/COPYRIGHT.txt new file mode 100644 index 00000000..aea758d6 --- /dev/null +++ b/COPYRIGHT.txt @@ -0,0 +1,10 @@ + +Copyright 2019. Triad National Security, LLC. All rights reserved. +This program was produced under U.S. Government contract 89233218CNA000001 for Los Alamos +National Laboratory (LANL), which is operated by Triad National Security, LLC for the U.S. +Department of Energy/National Nuclear Security Administration. All rights in the program are +reserved by Triad National Security, LLC, and the U.S. Department of Energy/National Nuclear +Security Administration. The Government is granted for itself and others acting on its behalf a +nonexclusive, paid-up, irrevocable worldwide license in this material to reproduce, prepare +derivative works, distribute copies to the public, perform publicly and display publicly, and to permit +others to do so. diff --git a/LICENSE.txt b/LICENSE.txt index af0925a1..2f40f860 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,15 +1,5 @@ -Copyright 2019. Triad National Security, LLC. All rights reserved. -This program was produced under U.S. Government contract 89233218CNA000001 for Los Alamos -National Laboratory (LANL), which is operated by Triad National Security, LLC for the U.S. -Department of Energy/National Nuclear Security Administration. All rights in the program are -reserved by Triad National Security, LLC, and the U.S. Department of Energy/National Nuclear -Security Administration. The Government is granted for itself and others acting on its behalf a -nonexclusive, paid-up, irrevocable worldwide license in this material to reproduce, prepare -derivative works, distribute copies to the public, perform publicly and display publicly, and to permit -others to do so. - This program is open source under the BSD-3 License. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/README.rst b/README.rst index 3bf6a020..45252206 100644 --- a/README.rst +++ b/README.rst @@ -106,6 +106,7 @@ The Journal of chemical physics, 148(24), 241715. See AUTHORS.txt for information on authors. See LICENSE.txt for licensing information. hippynn is licensed under the BSD-3 license. +See COPYRIGHT.txt for copyright information. Triad National Security, LLC (Triad) owns the copyright to hippynn, which it identifies as project number LA-CC-19-093. diff --git a/docs/source/conf.py b/docs/source/conf.py index a47dfe54..8e707bdf 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -45,9 +45,11 @@ "no-show-inheritance": True, "special-members": "__init__", } +autodoc_member_order = 'bysource' + # The following are highly optional, so we mock them for doc purposes. -autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba", "triton", "pytorch_lightning"] +autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba", "triton", "pytorch_lightning", 'triton', 'scipy'] # -- Options for HTML output ------------------------------------------------- diff --git a/docs/source/examples/controller.rst b/docs/source/examples/controller.rst index 83de14a5..cc2d5016 100644 --- a/docs/source/examples/controller.rst +++ b/docs/source/examples/controller.rst @@ -1,7 +1,6 @@ Controller ========== - How to define a controller for more customized control of the training process. We assume that there is a set of ``training_modules`` assembled and a ``database`` object has been constructed. diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index 78703eac..d72ee5f3 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -3,8 +3,8 @@ Examples Here are some examples about how to use various features in ``hippynn``. Besides the :doc:`/examples/minimal_workflow` example, -the examples are just snippets. For runnable example scripts, see -`the examples at the hippynn github repository`_ +the examples are just snippets, rather than full scripts. +For runnable example scripts, see `the examples at the hippynn github repository`_ .. _`the examples at the hippynn github repository`: https://github.com/lanl/hippynn/tree/development/examples @@ -23,5 +23,5 @@ the examples are just snippets. For runnable example scripts, see mliap_unified excited_states weighted_loss - + lightning diff --git a/docs/source/examples/lightning.rst b/docs/source/examples/lightning.rst new file mode 100644 index 00000000..bb572426 --- /dev/null +++ b/docs/source/examples/lightning.rst @@ -0,0 +1,20 @@ +Pytorch Lightning module +======================== + + +Hippynn incldues support for distributed training using `pytorch-lightning`_. +This can be accessed using the :class:`hippynn.experiment.HippynnLightningModule` class. +The class has two class-methods for creating the lightning module using the same +types of arguments that would be used for an ordinary hippynn experiment. +These are :meth:`hippynn.experiment.HippynnLightningModule.from_experiment_setup` +and :meth:`hippynn.experiment.HippynnLightningModule.from_train_setup`. +Alternatively, you may construct and supply the arguments for the module yourself. 
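+
+As a minimal sketch (mirroring ``tests/lightning_QM7_test.py`` in the repository), assuming
+``training_modules``, ``database``, and ``experiment_params`` have already been assembled as
+for an ordinary hippynn experiment::
+
+    import pytorch_lightning as pl
+    from hippynn.experiment import HippynnLightningModule
+
+    lightmod, datamodule = HippynnLightningModule.from_experiment_setup(
+        training_modules, database, experiment_params
+    )
+    # max_epochs=-1: the hippynn controller, not lightning, decides when to stop.
+    trainer = pl.Trainer(accelerator="cpu", devices=2, max_epochs=-1)
+    trainer.fit(model=lightmod, datamodule=datamodule)
+    trainer.test(datamodule=datamodule, ckpt_path="best")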
+ +Finally, in additional to the usual pytorch lightning arguments, +the hippynn lightning module saves an additional file, `experiment_structure.pt`, +which needs to be provided as an argument to the +:meth:`hippynn.experiment.HippynnLightningModule.load_from_checkpoint` constructor. + + +.. _pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning + diff --git a/docs/source/index.rst b/docs/source/index.rst index 50bcd450..fc17eb26 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -8,31 +8,53 @@ We hope you enjoy your stay. What is hippynn? ================ -`hippynn` is a python library for machine learning on atomistic systems. +``hippynn`` is a python library for machine learning on atomistic systems +using `pytorch`_. We aim to provide high-performance modular design so that different components can be re-used, extended, or added to. You can find more information -at the :doc:`/user_guide/features` page. The development home is located -at `the hippynn github repository`_, which also contains `many example files`_ +about overall library features at the :doc:`/user_guide/features` page. +The development home is located at `the github github repository`_, which also contains `many example files`_. +Additionally, the :doc:`user guide ` aims to describe abstract +aspects of the library, while the +:doc:`examples documentation section ` aims to show +more concretely how to perform tasks with hippynn. Finally, the +:doc:`api documentation ` contains a comprehensive +listing of the library components and their documentation. The main components of hippynn are constructing models, loading databases, training the models to those databases, making predictions on new databases, -and interfacing with other atomistic codes. In particular, we provide interfaces -to `ASE`_ (prediction), `PYSEQM`_ (training/prediction), and `LAMMPS`_ (prediction). +and interfacing with other atomistic codes for operations such as molecular dynamics. +In particular, we provide interfaces to `ASE`_ (prediction), +`PYSEQM`_ (training/prediction), and `LAMMPS`_ (prediction). hippynn is also used within `ALF`_ for generating machine learned potentials along with their training data completely from scratch. -Multiple formats for training data are supported, including -Numpy arrays, the ASE Database, `fitSNAP`_ JSON format, and `ANI HDF5 files`_. +Multiple :doc:`database formats ` for training data are supported, including +Numpy arrays, `ASE`_-compatible formats, `FitSNAP`_ JSON format, and `ANI HDF5 files`_. + +``hippynn`` includes many tools, such as an :doc:`ASE calculator`, +a :doc:`LAMMPS MLIAP interface`, +:doc:`batched prediction ` and batched geometry optimization, +:doc:`automatic ensemble creation `, +:doc:`restarting training from checkpoints `, +:doc:`sample-weighted loss functions `, +:doc:`distributed training with pytorch lightning `, +and more. + +``hippynn`` is highly modular, and if you are a model developer, interfacing your +pytorch model into the hippynn node/graph system will make it simple and easy for users +to build models of energy, charge, bond order, excited state energies, and more. .. _`ASE`: https://wiki.fysik.dtu.dk/ase/ .. _`PYSEQM`: https://github.com/lanl/PYSEQM/ .. _`LAMMPS`: https://www.lammps.org -.. _`fitSNAP`: https://github.com/FitSNAP/FitSNAP +.. _`FitSNAP`: https://github.com/FitSNAP/FitSNAP .. _`ANI HDF5 files`: https://doi.org/10.1038/s41597-020-0473-z .. _`ALF`: https://github.com/lanl/ALF/ -.. 
_`the hippynn github repository`: https://github.com/lanl/hippynn/ +.. _`the github github repository`: https://github.com/lanl/hippynn/ .. _`many example files`: https://github.com/lanl/hippynn/tree/development/examples +.. _`pytorch`: https://pytorch.org .. toctree:: diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 4064fea9..c8a07152 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -2,7 +2,6 @@ Installation ============ - Requirements ^^^^^^^^^^^^ @@ -43,6 +42,8 @@ Interfacing codes: .. _LAMMPS: https://www.lammps.org/ .. _PYSEQM: https://github.com/lanl/PYSEQM .. _pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning +.. _hippynn: https://github.com/lanl/hippynn/ + Installation Instructions ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -67,7 +68,6 @@ Clone the hippynn_ repository and navigate into it, e.g.:: $ git clone https://github.com/lanl/hippynn.git $ cd hippynn -.. _hippynn: https://github.com/lanl/hippynn/ Dependencies using conda diff --git a/docs/source/user_guide/ckernels.rst b/docs/source/user_guide/ckernels.rst index c810bbcd..eb504da7 100644 --- a/docs/source/user_guide/ckernels.rst +++ b/docs/source/user_guide/ckernels.rst @@ -60,7 +60,7 @@ The three custom kernels correspond to the interaction sum in hip-nn: .. math:: - a'_{i,a} = = \sum_{\nu,b} V^\nu_{a,b} e^{\nu}_{i,b} + a'_{i,a} = \sum_{\nu,b} V^\nu_{a,b} e^{\nu}_{i,b} e^{\nu}_{i,a} = \sum_p s^\nu_{p} z_{p_j,a} diff --git a/docs/source/user_guide/concepts.rst b/docs/source/user_guide/concepts.rst index 62cd8ff5..79b4faf6 100644 --- a/docs/source/user_guide/concepts.rst +++ b/docs/source/user_guide/concepts.rst @@ -45,8 +45,9 @@ Graphs A :class:`~hippynn.graphs.GraphModule` is a 'compiled' set of nodes; a ``torch.nn.Module`` that executes the graph. -GraphModules are used in a number of places within hippynn. - +GraphModules are used in a number of places within hippynn, +such as the model, the loss, the evaluator, the predictor, the ASE interface, +and the LAMMPS interface objects all use GraphModules. Experiment ^^^^^^^^^^ diff --git a/docs/source/user_guide/databases.rst b/docs/source/user_guide/databases.rst index 4448033b..2b339cf7 100644 --- a/docs/source/user_guide/databases.rst +++ b/docs/source/user_guide/databases.rst @@ -31,12 +31,45 @@ the [i,j] element of the cell gives the j cartesian coordinate of cell vector i. massive difficulties fitting to periodic boundary conditions, you may check the transposed version of your cell data, or compute the RDF. +Database Formats and notes +--------------------------- -ASE Objects Database handling ----------------------------------------------------------- -If your training data is stored as ASE files of any type (.json,.db,.xyz,.traj ... etc.) it can be loaded directly -a Database for hippynn. +Numpy arrays on disk +........................ + +see :class:`hippynn.databases.NPZDatabase` (if arrays are stored +in a `.npz` dictionary) or :class:`hippynn.databases.DirectoryDatabase` +(if each array is in its own file). + +Numpy arrays in memory +........................ + +Use the base :class:`hippynn.databases.Database` class directly to initialize +a database from a dictionary mapping db_names to numpy arrays. + +pyanitools H5 files +........................ + +See :class:`hippynn.databases.PyAniFileDB` and see :class:`hippynn.databases.PyAniDirectoryDB`. + +This format requires ``h5py`` and ``ase`` to be installed. + +Snap JSON Format +........................ 
+ +See :class:`hippynn.databases.SNAPDirectoryDatabase`. This format requires ``ase`` to be installed. + +For more information on this format, see the FitSNAP_ software. + +.. _FitSNAP: https://fitsnap.github.io + +ASE Database +........................ + +If your training data is stored as ASE files of any type, +(.json,.db,.xyz,.traj ... etc.) it can be loaded directly +as a Database for hippynn. The ASE database :class:`~hippynn.databases.AseDatabase` can be loaded with ASE installed. -See ~/examples/ase_db_example.py for a basic example utilzing the class. \ No newline at end of file +See ~/examples/ase_db_example.py for a basic example utilizing the class. \ No newline at end of file diff --git a/docs/source/user_guide/features.rst b/docs/source/user_guide/features.rst index 06fac16f..b95d6158 100644 --- a/docs/source/user_guide/features.rst +++ b/docs/source/user_guide/features.rst @@ -11,7 +11,7 @@ Modular set of pytorch layers for atomistic operations if you want to use them in your scripts without using the rest of the features provided here -- no problem! -API documentation for :mod:`~hippynn.layers` +API documentation for :mod:`~hippynn.layers` and :mod:`~hippynn.networks` Graph level API for simple and flexible construction of models from pytorch components. --------------------------------------------------------------------------------------- @@ -26,6 +26,12 @@ Graph level API for simple and flexible construction of models from pytorch comp API documentation for :mod:`~hippynn.graphs` +For more information on nodes and graphs, see the `graph exploration ipython notebook`_ which can also +be found in the example files. + +.. _graph exploration ipython notebook: https://github.com/lanl/hippynn/blob/development/examples/graph_exploration.ipynb + + Plot level API for tracking your training. ---------------------------------------------------------- - Using the graph API, define quantities to evaluate before, during, or after training as @@ -46,7 +52,7 @@ API documentation for :mod:`~hippynn.experiment` Custom Kernels for fast execution ---------------------------------------------------------- - Certain operations are not efficiently written in pure pytorch, we provide - alternative implementations with ``numba`` + alternative implementations. - These are directly linked in with pytorch Autograd -- use them like native pytorch functions. - These provide advantages in memory footprint and speed - Includes CPU and GPU execution for custom kernels @@ -55,7 +61,8 @@ More information at :doc:`this page ` Interfaces ---------------------------------------------------------- -- ASE: Define `ase` calculators based on the graph-level API. -- PYSEQM: Use `pyseqm` calculations as nodes in a graph. +- ASE: Define ``ase`` calculators based on the graph-level API. +- PYSEQM: Use ``pyseqm`` calculations as nodes in a graph. +- LAMMPS: Create a file for use as a `pair style mliap` object. API documentation for :mod:`~hippynn.interfaces` \ No newline at end of file diff --git a/docs/source/user_guide/settings.rst b/docs/source/user_guide/settings.rst index c6764206..a4c8fcef 100644 --- a/docs/source/user_guide/settings.rst +++ b/docs/source/user_guide/settings.rst @@ -69,3 +69,8 @@ The following settings are available: - float between 0 and 1 - 1.0 - no + * - TIMEPLOT_AUTOSCALING + - If True, only provide log-scaled plots of training quantities over time if warranted by the data. If False, always produce all plots in linear, log, and loglog scales. 
+ - bool + - True + - yes diff --git a/examples/ani1x_training.py b/examples/ani1x_training.py index f97f6114..6d1608d0 100644 --- a/examples/ani1x_training.py +++ b/examples/ani1x_training.py @@ -108,7 +108,7 @@ def load_db(db_info, en_name, force_name, seed, anidata_location, n_workers): found_indices = ~np.isnan(database.arr_dict[en_name]) database.arr_dict = {k: v[found_indices] for k, v in database.arr_dict.items()} - database.make_trainvalidtest_split(0.1, 0.1) + database.make_trainvalidtest_split(test_size=0.1, valid_size=0.1) return database diff --git a/hippynn/__init__.py b/hippynn/__init__.py index 520356ff..bf59fa7d 100644 --- a/hippynn/__init__.py +++ b/hippynn/__init__.py @@ -7,27 +7,36 @@ from . import _version __version__ = _version.get_versions()['version'] -# Configurational settings +# Configuration settings from ._settings_setup import settings - # Pytorch modules from . import layers -from . import networks +from . import networks # wait this one is different from the other one. # Graph abstractions from . import graphs +from .graphs import nodes, IdxType, GraphModule, Predictor # Database loading from . import databases +from .databases import Database, NPZDatabase, DirectoryDatabase # Training/testing routines from . import experiment -from .experiment import setup_and_train +from .experiment import setup_and_train, train_model, setup_training,\ + test_model, load_model_from_cwd, load_checkpoint, load_checkpoint_from_cwd + +# Other subpackages +from . import molecular_dynamics +from . import optimizer # Custom Kernels from . import custom_kernels +from .custom_kernels import set_custom_kernels from . import pretraining +from .pretraining import hierarchical_energy_initialization from . import tools +from .tools import active_directory, log_terminal diff --git a/hippynn/databases/__init__.py b/hippynn/databases/__init__.py index e97ad715..91aca915 100644 --- a/hippynn/databases/__init__.py +++ b/hippynn/databases/__init__.py @@ -23,14 +23,15 @@ pass if has_ase: - from ..interfaces.ase_interface import AseDatabase - if has_h5: - from .h5_pyanitools import PyAniFileDB, PyAniDirectoryDB + from ..interfaces.ase_interface import AseDatabase + from .SNAPJson import SNAPDirectoryDatabase + if has_h5: + from .h5_pyanitools import PyAniFileDB, PyAniDirectoryDB all_list = ["Database", "DirectoryDatabase", "NPZDatabase"] if has_ase: - all_list += ["AseDatabase"] + all_list += ["AseDatabase", "SNAPDirectoryDatabase"] if has_h5: all_list += ["PyAniFileDB", "PyAniDirectoryDB"] __all__ = all_list diff --git a/hippynn/databases/database.py b/hippynn/databases/database.py index fa503763..15bdd93b 100644 --- a/hippynn/databases/database.py +++ b/hippynn/databases/database.py @@ -15,6 +15,7 @@ _AUTO_SPLIT_PREFIX = "split_mask_" + class Database: """ Class for holding a pytorch dataset, splitting it, generating dataloaders, etc." 
@@ -22,18 +23,18 @@ class Database: def __init__( self, - arr_dict: dict[str,torch.Tensor], + arr_dict: dict[str, np.ndarray], inputs: list[str], targets: list[str], - seed: [int,np.random.RandomState,tuple], - test_size: Union[float,int]=None, - valid_size: Union[float,int]=None, - num_workers: int=0, - pin_memory: bool=True, - allow_unfound:bool =False, - auto_split:bool =False, - device: torch.device=None, - dataloader_kwargs:dict[str,object]=None, + seed: [int, np.random.RandomState, tuple], + test_size: Union[float, int] = None, + valid_size: Union[float, int] = None, + num_workers: int = 0, + pin_memory: bool = True, + allow_unfound: bool = False, + auto_split: bool = False, + device: torch.device = None, + dataloader_kwargs: dict[str, object] = None, quiet=False, ): """ @@ -56,7 +57,7 @@ def __init__( :param quiet: If True, print little or nothing while loading. """ - # Restartable Children of this class should change this after super().__init__ . + # Restartable Children of this class should change this after calling super().__init__() . self.restarter = NoRestart() self.inputs = inputs @@ -75,11 +76,12 @@ def __init__( _var_list = self.var_list except RuntimeError: if not quiet: - print("Database inputs and/or targets not specified. " - "The database will not be checked against and model inputs and targets (db_info).") + print( + "Database inputs and/or targets not specified. " + "The database will not be checked against and model inputs and targets (db_info)." + ) _var_list = [] - for k in _var_list: if k not in arr_dict and k not in ("indices", "split_indices"): if allow_unfound: @@ -113,13 +115,15 @@ def __init__( if self.auto_split: if test_size is not None or valid_size is not None: - warnings.warn(f"Auto split was set but test and valid size was also set." - f" Ignoring supplied test and validation sizes ({test_size} and {valid_size}.") + warnings.warn( + f"Auto split was set but test and valid size was also set." + f" Ignoring supplied test and validation sizes ({test_size} and {valid_size}." + ) self.make_automatic_splits() if test_size is not None or valid_size is not None: if test_size is None or valid_size is None: - raise ValueError("Both test and valid size must be set for auto-splitting based on fractions") + raise ValueError("Both test_size and valid_size must be set for splitting when creating a database.") else: self.make_trainvalidtest_split(test_size=test_size, valid_size=valid_size) @@ -142,11 +146,16 @@ def var_list(self): raise RuntimeError(f"Database inputs not defined, set {Database}.targets.") return self.inputs + self.targets - def send_to_device(self, device=None): + def send_to_device(self, device: torch.device = None): """ Move the database to an accelerator device if possible. In some circumstances this can accelerate training. + .. Note:: + If the database is moved to a GPU, + pin_memory will be set to False + and num_workers will be set to 0. + :param device: device to move to, if None, try to auto-detect. :return: """ @@ -167,11 +176,13 @@ def send_to_device(self, device=None): for split, arrdict in self.splits.items(): for k in arrdict: arrdict[k] = arrdict[k].to(device) + return - def make_random_split(self, evaluation_mode, split_size): + def make_random_split(self, split_name: str, split_size: Union[int, float]): """ + Make a random split using self.random_state to select items. 
- :param evaluation_mode: String naming the split, can be anything, but 'train', 'valid', and 'test' are special.s + :param split_name: String naming the split, can be anything, but 'train', 'valid', and 'test' are special. :param split_size: int (number of items) or float<1, fraction of samples. :return: """ @@ -185,9 +196,25 @@ def make_random_split(self, evaluation_mode, split_size): split_indices.sort() - return self.make_explicit_split(evaluation_mode, split_indices) + return self.make_explicit_split(split_name, split_indices) + + def make_trainvalidtest_split(self, *, test_size: Union[int, float], valid_size: Union[int, float]): + """ + Make a split for train, valid, and test out of any remaining unsplit entries in the database. + The size is specified in terms of test and valid splits; the train split will be the remainder. + + If you wish to specify precise rows for each split, see `make_explict_split` + or `make_explicit_split_bool`. + + This function takes keyword-arguments only in order to prevent confusion over which + size is which. + + The types of both test_size and valid_size parameters must match. - def make_trainvalidtest_split(self, test_size, valid_size): + :param test_size: int (count) or float (fraction) of data to assign to test split + :param valid_size: int (count) or float (fraction) of data to assign to valid split + :return: None + """ if self.splitting_completed: raise RuntimeError("Database already split!") @@ -196,19 +223,18 @@ def make_trainvalidtest_split(self, test_size, valid_size): raise ValueError("If train or valid size is set as a fraction, then set test_size as a fraction") else: if valid_size + test_size > 1: - raise ValueError( - f"Test fraction ({test_size}) plus valid fraction " f"({valid_size}) are greater than 1!" - ) + raise ValueError(f"Test fraction ({test_size}) plus valid fraction " f"({valid_size}) are greater than 1!") valid_size /= 1 - test_size self.make_random_split("test", test_size) self.make_random_split("valid", valid_size) self.split_the_rest("train") + return - def make_explicit_split(self, evaluation_mode, split_indices): + def make_explicit_split(self, split_name:str, split_indices: np.ndarray): """ - :param evaluation_mode: name for split, typically 'train', 'valid', 'test' + :param split_name: name for split, typically 'train', 'valid', 'test' :param split_indices: the indices of the items for the split :return: """ @@ -227,18 +253,18 @@ def make_explicit_split(self, evaluation_mode, split_indices): where_complement = np.where(complement_mask) # Split off data, and keep the rest. 
- self.splits[evaluation_mode] = {k: torch.from_numpy(self.arr_dict[k][where_index]) for k in self.arr_dict} - if "split_indices" not in self.splits[evaluation_mode]: + self.splits[split_name] = {k: torch.from_numpy(self.arr_dict[k][where_index]) for k in self.arr_dict} + if "split_indices" not in self.splits[split_name]: if not self.quiet: - print(f"Adding split indices for split: {evaluation_mode}") - self.splits[evaluation_mode]["split_indices"] = torch.arange(len(split_indices), dtype=torch.int64) + print(f"Adding split indices for split: {split_name}") + self.splits[split_name]["split_indices"] = torch.arange(len(split_indices), dtype=torch.int64) for k, v in self.arr_dict.items(): self.arr_dict[k] = v[where_complement] if not self.quiet: - print(f"Arrays for split: {evaluation_mode}") - prettyprint_arrays(self.splits[evaluation_mode]) + print(f"Arrays for split: {split_name}") + prettyprint_arrays(self.splits[split_name]) if arrdict_len(self.arr_dict) == 0: if not self.quiet: @@ -246,25 +272,28 @@ def make_explicit_split(self, evaluation_mode, split_indices): self.splitting_completed = True return - def make_explicit_split_bool(self, evaluation_mode, split_mask): + def make_explicit_split_bool(self, split_name: str, + split_mask: Union[np.ndarray, torch.tensor]): """ - :param evaluation_mode: name for split, typically 'train', 'valid', 'test' + :param split_name: name for split, typically 'train', 'valid', 'test' :param split_mask: a boolean array for where to split :return: """ + if isinstance(split_mask, torch.tensor): + split_mask = split_mask.numpy() if split_mask.dtype != np.bool_: if not np.isin(split_mask, [0, 1]).all(): raise ValueError(f"Mask function contains invalid values. Values found: {np.unique(split_mask)}") else: split_mask = split_mask.astype(np.bool_) - indices = self.arr_dict['indices'][split_mask] - self.make_explicit_split(evaluation_mode, indices) + indices = self.arr_dict["indices"][split_mask] + self.make_explicit_split(split_name, indices) return - def split_the_rest(self, evaluation_mode): - self.make_explicit_split(evaluation_mode, self.arr_dict["indices"]) + def split_the_rest(self, split_name: str): + self.make_explicit_split(split_name, self.arr_dict["indices"]) self.splitting_completed = True return @@ -296,9 +325,9 @@ def add_split_masks(self, dict_to_add_to=None, split_prefix=None): for sprime, split in self.splits.items(): if sprime == s: - mask = np.ones_like(split['indices'], dtype=np.bool_) + mask = np.ones_like(split["indices"], dtype=np.bool_) else: - mask = np.zeros_like(split['indices'], dtype=np.bool_) + mask = np.zeros_like(split["indices"], dtype=np.bool_) if write_tensor: mask = torch.as_tensor(mask) @@ -336,7 +365,7 @@ def make_automatic_splits(self, split_prefix=None, dry_run=False): if k.startswith(split_prefix): if arr.ndim != 1: raise ValueError(f"Split mask for '{k}' has too many dimensions. 
Shape: {arr.shape=}") - if arr.dtype == np.dtype('bool'): + if arr.dtype == np.dtype("bool"): mask_vars.add(k) elif arr.dtype is np.int and arr.ndim == 1: if np.isin(arr, [0, 1]).all(): @@ -350,7 +379,7 @@ def make_automatic_splits(self, split_prefix=None, dry_run=False): if not len(mask_vars): raise ValueError("No split mask detected.") - masks = {k[len(split_prefix):]: self.arr_dict[k].astype(bool) for k in mask_vars} + masks = {k[len(split_prefix) :]: self.arr_dict[k].astype(bool) for k in mask_vars} if not self.quiet: print("Auto-detected splits:", list(masks.keys())) @@ -369,13 +398,15 @@ def make_automatic_splits(self, split_prefix=None, dry_run=False): mask_counts += arr.astype(int) if not (mask_counts == 1).all(): set_of_counts = set(mask_counts) - raise ValueError(f" Auto-splitting requires unique split for each item." + - f" Items with the following split counts were detected: {set_of_counts}") + raise ValueError( + f" Auto-splitting requires unique split for each item." + + f" Items with the following split counts were detected: {set_of_counts}" + ) if dry_run: return - masks = {k: self.arr_dict['indices'][m] for k, m in masks.items()} + masks = {k: self.arr_dict["indices"][m] for k, m in masks.items()} for k, m in masks.items(): self.make_explicit_split(k, m) @@ -388,11 +419,18 @@ def make_automatic_splits(self, split_prefix=None, dry_run=False): return - def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample=False): + def make_generator(self, + split_name: str, + evaluation_mode: str, + batch_size: Union[int, None] = None, + subsample: Union[float, bool] = False + ): """ Makes a dataloader for the given type of split and evaluation mode of the model. - :param split_type: str; "train", "valid", or "test" ; selects data to use + In most cases, you do not need to call this function directly as a user. + + :param split_name: str; "train", "valid", or "test" ; selects data to use :param evaluation_mode: str; "train" or "eval". Used for whether to shuffle. :param batch_size: passed to pytorch :param subsample: fraction to subsample @@ -402,16 +440,14 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample if not self.splitting_completed: raise ValueError("Database has not yet been split.") - if split_type not in self.splits: - raise ValueError(f"Split {split_type} Invalid. Current splits:{list(self.splits.keys())}") + if split_name not in self.splits: + raise ValueError(f"Split {split_name} Invalid. Current splits:{list(self.splits.keys())}") - data = [self.splits[split_type][k] for k in self.var_list] + data = [self.splits[split_name][k] for k in self.var_list] if evaluation_mode == "train": - if split_type != "train": - raise ValueError( - "evaluation mode 'train' can only be used with training data." "(got {})".format(split_type) - ) + if split_name != "train": + raise ValueError("evaluation mode 'train' can only be used with training data." 
"(got {})".format(split_name)) shuffle = True elif evaluation_mode == "eval": shuffle = False @@ -423,9 +459,7 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample n_total = data[0].shape[0] n_selected = int(n_total * subsample) sampled_indices = torch.argsort(torch.rand(n_total))[:n_selected] - # sampled_indices = torch.rand(data[0].shape[0]) < subsample dataset = Subset(dataset, sampled_indices) - # data = [a[sampled_indices] for a in data] generator = DataLoader( dataset, @@ -463,8 +497,8 @@ def _array_stat_helper(self, key, species_key, atomwise, norm_per_atom, norm_axi n_atoms = (self.arr_dict[species_key] > 0).sum(axis=1) # Transposes broadcast the result rightwards instead of leftwards. # numpy transpose on higher-order arrays reverses all dimensions. - prop = (prop.T/n_atoms).T - stat_prop = (stat_prop.T/n_atoms).T + prop = (prop.T / n_atoms).T + stat_prop = (stat_prop.T / n_atoms).T mean = stat_prop.mean() std = stat_prop.std() @@ -473,8 +507,16 @@ def _array_stat_helper(self, key, species_key, atomwise, norm_per_atom, norm_axi return prop, mean, std - - def remove_high_property(self, key, atomwise, norm_per_atom=False, species_key=None, cut=None, std_factor=10, norm_axis=None): + def remove_high_property( + self, + key: str, + atomwise: bool, + norm_per_atom: bool = False, + species_key: str = None, + cut: Union[float, None] = None, + std_factor: Union[float, None] = 10, + norm_axis: Union[int, None] = None, + ): """ For removing outliers from a dataset. Use with caution; do not inadvertently remove outliers from benchmarks! @@ -505,7 +547,7 @@ def remove_high_property(self, key, atomwise, norm_per_atom=False, species_key=N if std_factor is not None: prop, mean, std = self._array_stat_helper(key, species_key, atomwise, norm_per_atom, norm_axis) - large_property_mask = np.abs(prop - mean)/std > std_factor + large_property_mask = np.abs(prop - mean) / std > std_factor # Scan over all non-batch indices. non_batch_axes = tuple(range(1, prop.ndim)) drop_mask = np.sum(large_property_mask, axis=non_batch_axes) > 0 @@ -514,7 +556,23 @@ def remove_high_property(self, key, atomwise, norm_per_atom=False, species_key=N print(f"Removed {drop_mask.astype(int).sum()} outlier systems in variable {key} due to std. factor.") self.make_explicit_split(f"failed_std_fac_{key}", indices) - def write_h5(self, split=None, h5path=None, species_key='species', overwrite=False): + def write_h5(self, + split: Union[str, None] = None, + h5path: Union[str, None] = None, + species_key: str = "species", + overwrite:bool = False): + """ + Write this database to the pyanitools h5 format. + See :func:`hippynn.databases.h5_pyanitools.write_h5` for details. + + Note: This function will error if h5py is not installed. 
+ + :param split: + :param h5path: + :param species_key: + :param overwrite: + :return: + """ try: from .h5_pyanitools import write_h5 as write_h5_function @@ -523,33 +581,41 @@ def write_h5(self, split=None, h5path=None, species_key='species', overwrite=Fal return write_h5_function(self, split=split, file=h5path, species_key=species_key, overwrite=overwrite) - def write_npz(self, file: str, record_split_masks: bool = True, compressed:bool =True, overwrite: bool = False, split_prefix=None, return_only=False): + def write_npz( + self, + file: str, + record_split_masks: bool = True, + compressed: bool = True, + overwrite: bool = False, + split_prefix: Union[str, None] = None, + return_only: bool = False, + ): """ :param file: str, Path, or file object compatible with np.save - :param record_split_masks: + :param record_split_masks: whether to generate and place masks for the splits into the saved database. + :param compressed: whether to use np.savez_compressed (True) or np.savez :param overwrite: Whether to accept an existing path. Only used if fname is str or path. - :param split_prefix: optionally change the prefix for the masks computed by the splits. + :param split_prefix: optionally override the prefix for the masks computed by the splits. :param return_only: if True, ignore the file string and just return the resulting dictionary of numpy arrays. - :return: + :return: """ if split_prefix is None: split_prefix = _AUTO_SPLIT_PREFIX if not self.splitting_completed: - raise ValueError("Cannot write an incompletely split database to npz file.\n" + - "You can split the rest using `database.split_the_rest('other_data')`\n" + - "to put the remaining data into a new split named 'other_data'") + raise ValueError( + "Cannot write an incompletely split database to npz file.\n" + + "You can split the rest using `database.split_the_rest('other_data')`\n" + + "to put the remaining data into a new split named 'other_data'" + ) # get combined dictionary of arrays. - np_dict = {sname: - {arr_name: array.to('cpu').numpy() for arr_name, array in split.items()} - for sname, split in self.splits.items()} + np_dict = {sname: {arr_name: array.to("cpu").numpy() for arr_name, array in split.items()} for sname, split in self.splits.items()} # insert split masks if requested. if record_split_masks: self.add_split_masks(dict_to_add_to=np_dict, split_prefix=split_prefix) - # Stack numpy arrays: arr_dict = {} a_split = list(np_dict.values())[0] @@ -577,10 +643,12 @@ def write_npz(self, file: str, record_split_masks: bool = True, compressed:bool return arr_dict - def sort_by_index(self, index_name='indices'): + def sort_by_index(self, index_name: str = "indices"): """ + Sort arrays in each split of the database by an index key. - The default is 'indices', also possible is 'split_indices', or any other variable name in the database. + + The default is 'indices', also possible is 'split_indices', or any other variable name in the database. :param index_name: :return: None @@ -592,12 +660,14 @@ def sort_by_index(self, index_name='indices'): for k, v in split.items(): split[k] = v[ind_order] - def trim_by_species(self, species_key: str, keep_splits_same_size: bool =True): + def trim_by_species(self, species_key: str, keep_splits_same_size: bool = True): """ Remove any excess padding in a database. + :param species_key: what array to use to mark atom presence. - :param keep_splits_same_size: true: trim by the minimum amount across splits, false: trim by the maximum amount for each split. 
- :return: + :param keep_splits_same_size: true: trim by the minimum amount across splits, + false: trim by the maximum amount for each split. + :return: None """ if not self.splitting_completed: raise ValueError("Cannot trim arrays until splitting has been completed.") @@ -653,7 +723,12 @@ def trim_by_species(self, species_key: str, keep_splits_same_size: bool =True): return - def get_device(self): + def get_device(self) -> torch.device: + """ + Determine what device the database resides on. Raises ValueError if multiple devices are encountered. + + :return: device. + """ if not self.splitting_completed: raise ValueError("Device should not be changed before splitting is complete.") @@ -664,11 +739,15 @@ def get_device(self): device = devices.pop() return device - def make_database_cache(self, file="./hippynn_db_cache.npz", overwrite=False, **override_kwargs): + def make_database_cache(self, file: str = "./hippynn_db_cache.npz", overwrite: bool = False, **override_kwargs) -> "Database": """ Cache the database as-is, and re-open it. Useful for creating an easy restart script if the storage space is available. + The new datatbase will by default inherit the properties of this database. + + usage: + >>> database = database.make_database_cache() :param file: where to store the database :param overwrite: whether to overwrite an existing cache file with this name. @@ -702,14 +781,20 @@ def make_database_cache(self, file="./hippynn_db_cache.npz", overwrite=False, ** if not self.quiet: print("Writing Cached database to", file) - self.write_npz(file=file, - record_split_masks=True, # allows inheriting of splits from this db. - overwrite=overwrite, - return_only=False) + self.write_npz( + file=file, record_split_masks=True, overwrite=overwrite, return_only=False # allows inheriting of splits from this db. + ) # now reload cached file. return NPZDatabase(**arguments) -def compute_index_mask(indices, index_pool): + +def compute_index_mask(indices: np.ndarray, index_pool: np.ndarray) -> np.ndarray: + """ + + :param indices: + :param index_pool: + :return: + """ if not np.all(np.isin(indices, index_pool)): raise ValueError("Provided indices not in database") @@ -723,9 +808,9 @@ def compute_index_mask(indices, index_pool): return index_mask -def prettyprint_arrays(arr_dict): +def prettyprint_arrays(arr_dict: dict[str: np.ndarray]): """ - Pretty-print array dictionary + Pretty-print array dictionary. :return: None """ column_format = "| {:<30} | {:<18} | {:<28} |" diff --git a/hippynn/databases/h5_pyanitools.py b/hippynn/databases/h5_pyanitools.py index fd4c6a27..a50a9f2d 100644 --- a/hippynn/databases/h5_pyanitools.py +++ b/hippynn/databases/h5_pyanitools.py @@ -1,7 +1,7 @@ """ -Read Databases in the ANI H5 format. -Note: You will need `pyanitools.py` to be importable to import this module. +Read Databases in the pyanitools H5 format. 
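+
+A hypothetical loading sketch (the file name, input/target names, and ``seed`` value are
+illustrative and depend on the contents of your dataset):
+
+>>> db = PyAniFileDB(file="data.h5", inputs=["species", "coordinates"],
+...                  targets=["energies"], seed=0)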
+ """ import os @@ -37,8 +37,7 @@ def extract_full_file(self, file, species_key="species"): for c in progress_bar(x, desc="Data Groups", unit="group", total=x.group_size()): batch_dict = {} if species_key not in c: - raise ValueError(f"Species key '{species_key}' not found' in file {file}!\n" - f"\tFound keys: {set(c.keys())}") + raise ValueError(f"Species key '{species_key}' not found' in file {file}!\n" f"\tFound keys: {set(c.keys())}") for k, v in c.items(): # Filter things we don't need if k in self._IGNORE_KEYS: @@ -104,18 +103,19 @@ def determine_key_structure(self, batch_list, sys_count, n_atoms_max, species_ke shape_scheme[k][axis] = n_atoms_max shape_scheme[k][0] = sys_count - padding_scheme['sys_number'] = [] + padding_scheme["sys_number"] = [] return padding_scheme, shape_scheme, bkey def process_batches(self, batches, n_atoms_max, sys_count, species_key="species"): # Get padding abd shape info and batch size key - padding_scheme, shape_scheme, size_key =\ - self.determine_key_structure(batches, sys_count, n_atoms_max, species_key=species_key) + padding_scheme, shape_scheme, size_key = self.determine_key_structure(batches, sys_count, n_atoms_max, species_key=species_key) # add system numbers to the final arrays - shape_scheme['sys_number'] = [sys_count, ] - batches[0]['sys_number'] = np.asarray([0], dtype=np.int64) + shape_scheme["sys_number"] = [ + sys_count, + ] + batches[0]["sys_number"] = np.asarray([0], dtype=np.int64) arr_dict = {} for k, shape in shape_scheme.items(): @@ -126,7 +126,7 @@ def process_batches(self, batches, n_atoms_max, sys_count, species_key="species" for i, b in enumerate(progress_bar(batches, desc="Processing Batches", unit="batch")): # Get batch metadata n_sys = b[size_key].shape[0] - b['sys_number'] = np.asarray([i], dtype=np.int64) + b["sys_number"] = np.asarray([i], dtype=np.int64) sys_end = sys_start + n_sys # n_atoms_batch = b[species_key].shape[1] # don't need this! 
@@ -173,7 +173,7 @@ def filter_arrays(self, arr_dict, allow_unfound=False, quiet=False): class PyAniFileDB(Database, PyAniMethods, Restartable): - def __init__(self, file, inputs, targets, *args, allow_unfound=False, species_key="species", quiet=False, driver='core', **kwargs): + def __init__(self, file, inputs, targets, *args, allow_unfound=False, species_key="species", quiet=False, driver="core", **kwargs): """ :param file: @@ -197,7 +197,14 @@ def __init__(self, file, inputs, targets, *args, allow_unfound=False, species_ke super().__init__(arr_dict, inputs, targets, *args, **kwargs, quiet=quiet, allow_unfound=allow_unfound) self.restarter = self.make_restarter( - file, inputs, targets, *args, **kwargs, driver=driver, quiet=quiet, allow_unfound=allow_unfound, + file, + inputs, + targets, + *args, + **kwargs, + driver=driver, + quiet=quiet, + allow_unfound=allow_unfound, species_key=species_key, ) @@ -211,8 +218,19 @@ def load_arrays(self, allow_unfound=False, quiet=False): class PyAniDirectoryDB(Database, PyAniMethods, Restartable): - def __init__(self, directory, inputs, targets, *args, files=None, allow_unfound=False, species_key="species", - quiet=False, driver='core', **kwargs): + def __init__( + self, + directory, + inputs, + targets, + *args, + files=None, + allow_unfound=False, + species_key="species", + quiet=False, + driver="core", + **kwargs, + ): self.directory = directory self.files = files @@ -221,11 +239,10 @@ def __init__(self, directory, inputs, targets, *args, files=None, allow_unfound= self.species_key = species_key self.driver = driver - arr_dict = self.load_arrays(allow_unfound=allow_unfound,quiet=quiet) + arr_dict = self.load_arrays(allow_unfound=allow_unfound, quiet=quiet) super().__init__(arr_dict, inputs, targets, *args, **kwargs, quiet=quiet, allow_unfound=allow_unfound) - self.restarter = self.make_restarter(directory, inputs, targets, *args, files=files, quiet=quiet, - species_key=species_key, **kwargs) + self.restarter = self.make_restarter(directory, inputs, targets, *args, files=files, quiet=quiet, species_key=species_key, **kwargs) def load_arrays(self, allow_unfound=False, quiet=False): @@ -257,24 +274,29 @@ def load_arrays(self, allow_unfound=False, quiet=False): return arr_dict -def write_h5(database: Database, split: str = None, file: Path = None, species_key: str = 'species', overwrite=False): +def write_h5( + database: Database, + split: str = None, + file: Path = None, + species_key: str = "species", + overwrite: bool = False, +) -> dict: """ - :param database: database to get + :param database: Database to use :param split: str, None, or True; selects data split to save. If None, contents of arr_dict are used. If True, save all splits and save split masks as well. - :param file: where to save the database. + :param file: where to save the database. if None, does not save the file. :param species_key: the key used for system contents (padding and chemical formulas) :param overwrite: boolean; enables over-writing of h5 file. - :return: dictionary of ANI-style systems. + :return: dictionary of pyanitools-format systems. 
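+
+    A hypothetical usage sketch (the output path is illustrative):
+
+    >>> systems = write_h5(database, split=True, file="all_splits.h5", overwrite=True)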
""" if split is True: database = database.write_npz("", record_split_masks=True, return_only=True) - print("writenpz", database.keys()) elif split in database.splits: database = database.splits[split] - database = {k: v.to('cpu').numpy() for k,v in database.items()} + database = {k: v.to("cpu").numpy() for k, v in database.items()} elif split is None: database = database.arr_dict else: @@ -297,10 +319,8 @@ def write_h5(database: Database, split: str = None, file: Path = None, species_k n_atoms_max = db_species.shape[1] # determine which keys have second shape of N atoms - is_atom_var = { - k: (len(k_arr.shape) > 1) and (k_arr.shape[1] == n_atoms_max) for k, k_arr in database.items() - } - del (is_atom_var[species_key]) # species handled separately + is_atom_var = {k: (len(k_arr.shape) > 1) and (k_arr.shape[1] == n_atoms_max) for k, k_arr in database.items()} + del is_atom_var[species_key] # species handled separately # Create the data dictionary # Maps hashes of system chemical formulas to dictionaries of system information. @@ -343,7 +363,7 @@ def write_h5(database: Database, split: str = None, file: Path = None, species_k mol[k] = np.asarray(mol[k]) if np.issubdtype(mol[k].dtype, np.unicode_): - mol[k] = [el.encode('utf-8') for el in list(mol[k])] + mol[k] = [el.encode("utf-8") for el in list(mol[k])] mol[k] = np.array(mol[k]) # Store data if packer is not None: diff --git a/hippynn/databases/ondisk.py b/hippynn/databases/ondisk.py index fce60bdc..6fce1d37 100644 --- a/hippynn/databases/ondisk.py +++ b/hippynn/databases/ondisk.py @@ -14,7 +14,7 @@ class DirectoryDatabase(Database, Restartable): """ - Database stored as NPY files in a diectory. + Database stored as NPY files in a directory. :param directory: directory path where the files are stored :param name: prefix for the arrays. diff --git a/hippynn/experiment/__init__.py b/hippynn/experiment/__init__.py index 3a222e9b..f42a8091 100644 --- a/hippynn/experiment/__init__.py +++ b/hippynn/experiment/__init__.py @@ -12,9 +12,10 @@ from .assembly import assemble_for_training from .routines import setup_and_train, setup_training, train_model, test_model, SetupParams +from .serialization import load_checkpoint, load_checkpoint_from_cwd, load_model_from_cwd - -__all__ = ["assemble_for_training", "setup_and_train", "setup_training", "train_model", "test_model", "SetupParams",] +__all__ = ["assemble_for_training", "setup_and_train", "setup_training", "train_model", "test_model", "SetupParams", + "load_checkpoint", "load_checkpoint_from_cwd", "load_model_from_cwd"] try: from .lightning_trainer import HippynnLightningModule diff --git a/hippynn/experiment/lightning_trainer.py b/hippynn/experiment/lightning_trainer.py index ead8eb57..3b6e1d52 100644 --- a/hippynn/experiment/lightning_trainer.py +++ b/hippynn/experiment/lightning_trainer.py @@ -31,6 +31,9 @@ class HippynnLightningModule(pl.LightningModule): + """ + A pytorch lightning module for running a hippynn experiment. + """ def __init__( self, model: GraphModule, @@ -84,6 +87,15 @@ def __init__( @classmethod def from_experiment_setup(cls, training_modules: TrainingModules, database: Database, setup_params: SetupParams, **kwargs): + """ + Create a lightning module using the same arguments as for :func:`hippynn.experiment.setup_and_train`. 
+ + :param training_modules: + :param database: + :param setup_params: + :param kwargs: + :return: + """ training_modules, controller, metric_tracker = setup_training(training_modules, setup_params) return cls.from_train_setup(training_modules, database, controller, metric_tracker, **kwargs) @@ -98,6 +110,19 @@ def from_train_setup( batch_callbacks=None, **kwargs, ): + """ + Create a lightning module from the same arguments as for :func:`hippynn.experiment.train_model`. + + :param training_modules: + :param database: + :param controller: + :param metric_tracker: + :param callbacks: + :param batch_callbacks: + :param kwargs: + :return: + """ + model, loss, evaluator = training_modules @@ -131,6 +156,11 @@ def from_train_setup( return trainer, HippynnDataModule(database, controller.batch_size) def on_save_checkpoint(self, checkpoint) -> None: + """ + + :param checkpoint: + :return: + """ # Note to future developers: # trainer.log_dir property needs to be called on all ranks! This is weird but important; @@ -163,6 +193,16 @@ def on_save_checkpoint(self, checkpoint) -> None: @classmethod def load_from_checkpoint(cls, checkpoint_path, map_location=None, structure_file=None, hparams_file=None, strict=True, **kwargs): + """ + + :param checkpoint_path: + :param map_location: + :param structure_file: + :param hparams_file: + :param strict: + :param kwargs: + :return: + """ if structure_file is None: # Assume checkpoint_path is like /version_/checkpoints/.chkpt @@ -178,11 +218,20 @@ def load_from_checkpoint(cls, checkpoint_path, map_location=None, structure_file ) def on_load_checkpoint(self, checkpoint) -> None: + """ + + :param checkpoint: + :return: + """ cstate = checkpoint.pop("controller_state") self.controller.load_state_dict(cstate) return def configure_optimizers(self): + """ + + :return: + """ scheduler_list = [] for s in self.scheduler_list: @@ -201,14 +250,24 @@ def configure_optimizers(self): return optimizer_list, scheduler_list def on_train_epoch_start(self): + """ + + :return: + """ for optimizer in self.optimizer_list: print_lr(optimizer, print_=self.print) self.print("Batch size:", self.trainer.train_dataloader.batch_size) def training_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + :return: + """ batch_inputs = batch[: self.n_inputs] - batch_targets = batch[-self.n_targets :] + batch_targets = batch[-self.n_targets:] batch_model_outputs = self.model(*batch_inputs) batch_train_loss = self.loss(*batch_model_outputs, *batch_targets)[0] @@ -232,9 +291,21 @@ def _eval_step(self, batch, batch_idx): return batch_predictions def validation_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + :return: + """ return self._eval_step(batch, batch_idx) def test_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + :return: + """ return self._eval_step(batch, batch_idx) def _eval_epoch_end(self, prefix): @@ -259,10 +330,18 @@ def _eval_epoch_end(self, prefix): return def on_validation_epoch_end(self): + """ + + :return: + """ self._eval_epoch_end(prefix="valid_") return def on_test_epoch_end(self): + """ + + :return: + """ self._eval_epoch_end(prefix="test_") return @@ -326,10 +405,18 @@ def _eval_end(self, prefix, when=None) -> None: return def on_validation_end(self): + """ + + :return: + """ self._eval_end(prefix="valid_") return def on_test_end(self): + """ + + :return: + """ self._eval_end(prefix="test_", when="test") return @@ -360,10 +447,22 @@ def __init__(self, database: Database, batch_size): self.batch_size = 
batch_size def train_dataloader(self): + """ + + :return: + """ return self.database.make_generator("train", "train", self.batch_size) def val_dataloader(self): + """ + + :return: + """ return self.database.make_generator("valid", "eval", self.batch_size) def test_dataloader(self): + """ + + :return: + """ return self.database.make_generator("test", "eval", self.batch_size) diff --git a/hippynn/experiment/routines.py b/hippynn/experiment/routines.py index 84faa5c0..ed2e7746 100644 --- a/hippynn/experiment/routines.py +++ b/hippynn/experiment/routines.py @@ -21,6 +21,8 @@ from .. import tools from .assembly import TrainingModules from .step_functions import get_step_function +from ..databases import Database + from .. import custom_kernels @@ -101,7 +103,7 @@ def __post_init__(self): def setup_and_train( training_modules: TrainingModules, - database, + database: Database, setup_params: SetupParams, store_all_better=False, store_best=True, diff --git a/hippynn/layers/pairs/filters.py b/hippynn/layers/pairs/filters.py index d1aebedc..653cb505 100644 --- a/hippynn/layers/pairs/filters.py +++ b/hippynn/layers/pairs/filters.py @@ -4,10 +4,11 @@ from .open import _PairIndexer class FilterDistance(_PairIndexer): - """ Filters a list of tensors in *pair_lists by distance. + """ + Filters a list of tensors in pair_tensors by distance. pair_dist is first positional argument. - :param _PairIndexer: FilterDistance subclasses _PairIndexer so that the + FilterDistance subclasses _PairIndexer so that the FilterPairIndexers behave as regular PairIndexers. """ diff --git a/hippynn/molecular_dynamics/__init__.py b/hippynn/molecular_dynamics/__init__.py index 3bcc1722..622cb6ce 100644 --- a/hippynn/molecular_dynamics/__init__.py +++ b/hippynn/molecular_dynamics/__init__.py @@ -2,4 +2,7 @@ Molecular dynamics driver with great flexibility and customizability regarding which quantities which are evolved and what algorithms are used to evolve them. Calls a hippynn `Predictor` on current state during each MD step. """ -from .md import * \ No newline at end of file +from .md import MolecularDynamics, Variable, NullUpdater, VelocityVerlet, LangevinDynamics + + +__all__ = ["MolecularDynamics", "Variable", "NullUpdater", "VelocityVerlet", "LangevinDynamics"] diff --git a/hippynn/molecular_dynamics/md.py b/hippynn/molecular_dynamics/md.py index 8752990c..1375fd77 100644 --- a/hippynn/molecular_dynamics/md.py +++ b/hippynn/molecular_dynamics/md.py @@ -10,6 +10,7 @@ from ..graphs import Predictor from ..layers.pairs.periodic import wrap_systems_torch + class Variable: """ Tracks the state of a quantity (eg. position, cell, species, diff --git a/hippynn/tools.py b/hippynn/tools.py index df1507cb..15e4768e 100644 --- a/hippynn/tools.py +++ b/hippynn/tools.py @@ -11,7 +11,7 @@ from . import settings -class teed_file_output: +class TeedFileOutput: def __init__(self, *streams): self.streams = streams @@ -42,8 +42,8 @@ def log_terminal(file, *args, **kwargs): file = open(file, *args, **kwargs) else: close_on_exit = False - teed_stderr = teed_file_output(file, sys.stderr) - teed_stdout = teed_file_output(file, sys.stdout) + teed_stderr = TeedFileOutput(file, sys.stderr) + teed_stdout = TeedFileOutput(file, sys.stdout) with contextlib.redirect_stderr(teed_stderr): with contextlib.redirect_stdout(teed_stdout): try: @@ -102,6 +102,16 @@ def active_directory(dirname, create=None): def progress_bar(iterable, *args, **kwargs): + """ + Wrap an iterable in a progress bar according to hippynn's current progress bar settings. 
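+
+    A minimal sketch (``desc`` and ``unit`` are tqdm-style keywords, only used when a
+    tqdm-backed progress bar is active):
+
+    >>> for item in progress_bar(range(100), desc="items", unit="item"):
+    ...     pass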
+ + for args and kwargs, see tqdm documentation. + + :param iterable: + :param args: + :param kwargs: + :return: + """ if settings.PROGRESS is None: return iterable else: @@ -166,9 +176,12 @@ def unsqueeze_multiple(tensor, dims: tuple): tensor = tensor.unsqueeze(d) dims = tuple(d+1 for d in rest) return tensor + + def np_of_torchdefaultdtype(): return torch.ones(1, dtype=torch.get_default_dtype()).numpy().dtype + def is_equal_state_dict(d1, d2, raise_where=False): """ Checks if two pytorch state dictionaries are equal. Calls itself recursively From 144c160ea64122ae1666fe95de9ecd9380c2e5d8 Mon Sep 17 00:00:00 2001 From: "Michael G. Taylor" <119455260+mgt16-LANL@users.noreply.github.com> Date: Fri, 13 Sep 2024 14:24:52 -0600 Subject: [PATCH 6/6] Bugfixes to custom kernels, doc updates, changelog updates (#102) * Update docs and changelog 1. Replace all occurrences of `restore_db` with `restart_db`. A note is added to the documentation to reflect this change. 2. Update changelog for breaking changes of `restart_db` and `make_trainvalidtest_split`. * Fix typos and rephrase the note * Check cuda - torch.cuda.get_device_capability - Only run triton check if cuda is available. * update a lot of settings and doc releated things. fix custom kernel handling * update changelog, revert ipynb figure changes --------- Co-authored-by: Xinyang Li Co-authored-by: Nicholas Lubbers --- CHANGELOG.rst | 16 +- docs/source/conf.py | 9 +- docs/source/examples/mliap_unified.rst | 2 +- docs/source/examples/restarting.rst | 8 + docs/source/user_guide/settings.rst | 1 + examples/lammps/hippynn_lammps_example.ipynb | 2 +- hippynn/__init__.py | 17 +- hippynn/_settings_setup.py | 218 ++++++++++++------- hippynn/custom_kernels/__init__.py | 155 +++++++++---- hippynn/databases/database.py | 2 +- hippynn/experiment/device.py | 2 - hippynn/experiment/serialization.py | 1 - hippynn/layers/pairs/dispatch.py | 9 +- hippynn/molecular_dynamics/__init__.py | 7 +- hippynn/molecular_dynamics/md.py | 84 +++---- pyproject.toml | 2 +- tests/progress_settings.py | 23 ++ 17 files changed, 355 insertions(+), 203 deletions(-) create mode 100644 tests/progress_settings.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6842e0d5..f9d40038 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,8 +3,16 @@ Breaking changes: ----------------- -- set_e0_values has been renamed hierarchical_energy_initialization. The old name is - still provided but deprecated, and will be removed. +- ``set_e0_values`` has been renamed to ``hierarchical_energy_initialization``. + The old name is still provided but deprecated, and will be removed. +- The argument ``restore_db`` has been renamed to ``restart_db``. The affected + functions are ``load_checkpoint``, ``load_checkpoint_from_cwd``, and + ``restore_checkpoint``. +- ``database.make_trainvalidtest_split`` now only takes keyword arguments to + avoid confusions. Use ``make_trainvalidtest_split(test_size=a, valid_size=b)`` + instead of ``make_trainvalidtest_split(a, b)``. +- Invalid custom kernel specifications are now errors rather than warnings. + New Features: ------------- @@ -22,6 +30,7 @@ New Features: - Added tool to drastically simplify creating ensemble models. The ensemblized graphs are compatible with molecular dynamics codes such ASE and LAMMPS. - Added the ability to weight different systems/atoms/bonds in a loss function. +- Added new function to reload library settings. Improvements: @@ -35,6 +44,9 @@ Improvements: using a library setting. 
- Provide tunable regularization of HIP-NN-TS with an epsilon parameter, and set the default to use a better value for epsilon. +- Improved detection of valid custom kernel implementation. +- Improved computational efficiency of HIP-NN-TS network. + Bug Fixes: diff --git a/docs/source/conf.py b/docs/source/conf.py index 8e707bdf..945fc9e7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -23,6 +23,7 @@ # The full version, including alpha/beta/rc tags import hippynn + release = hippynn.__version__ # -- General configuration --------------------------------------------------- @@ -31,7 +32,6 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = ["sphinx.ext.autodoc", "sphinx_rtd_theme", "sphinx.ext.viewcode"] -add_module_names = False # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -45,12 +45,13 @@ "no-show-inheritance": True, "special-members": "__init__", } -autodoc_member_order = 'bysource' +autodoc_member_order = "bysource" # The following are highly optional, so we mock them for doc purposes. -autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba", "triton", "pytorch_lightning", 'triton', 'scipy'] - +# TODO: Can we programmatically get these from our list of optional dependencies? +autodoc_mock_imports = ["ase", "h5py", "seqm", "schnetpack", "cupy", "lammps", "numba", "triton", "pytorch_lightning", 'scipy'] +add_module_names = False # -- Options for HTML output ------------------------------------------------- diff --git a/docs/source/examples/mliap_unified.rst b/docs/source/examples/mliap_unified.rst index 4d627ee6..065608cb 100644 --- a/docs/source/examples/mliap_unified.rst +++ b/docs/source/examples/mliap_unified.rst @@ -11,7 +11,7 @@ species atomic symbols (whose order must agree with the order of the training hy Example:: - bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False) + bundle = load_checkpoint_from_cwd(map_location="cpu", restart_db=False) model = bundle["training_modules"].model energy_node = model.node_from_name("HEnergy") unified = MLIAPInterface(energy_node, ["Al"], model_device=torch.device("cuda")) diff --git a/docs/source/examples/restarting.rst b/docs/source/examples/restarting.rst index 75d14aa3..18a00949 100644 --- a/docs/source/examples/restarting.rst +++ b/docs/source/examples/restarting.rst @@ -43,6 +43,14 @@ or to use the default filenames and load from the current directory:: check = load_checkpoint_from_cwd() train_model(**check, callbacks=None, batch_callbacks=None) +.. note:: + In release 0.0.4, the ``restore_db`` argument has been renamed to + ``restart_db`` for internal consistence. ``restore_db`` in all scripts using + `hippynn > 0.0.3` should be replaced with ``restart_db``. The affected + functions are ``load_checkpoint``, ``load_checkpoint_from_cwd``, and + ``restore_checkpoint``. If `hippynn <= 0.0.3` is used, please keep the + original ``restore_db`` keyword. + If all you want to do is use a previously trained model, here is how to load the model only:: from hippynn.experiment.serialization import load_model_from_cwd diff --git a/docs/source/user_guide/settings.rst b/docs/source/user_guide/settings.rst index a4c8fcef..a7837531 100644 --- a/docs/source/user_guide/settings.rst +++ b/docs/source/user_guide/settings.rst @@ -11,6 +11,7 @@ There are four possible sources for settings. 3. 
A file specified by the environment variable `HIPPYNN_LOCAL_RC_FILE` which is treated the same as the user rc file. 4. Environment variables prefixed by ``HIPPYNN_``, e.g. ``HIPPYNN_DEFAULT_PLOT_FILETYPE``. +5. Arguments passed to :func:`hippynn.reload_settings`. These three sources are checked in order, so that values in later sources overwrite values found in earlier sources. diff --git a/examples/lammps/hippynn_lammps_example.ipynb b/examples/lammps/hippynn_lammps_example.ipynb index f42f31a3..12ed548f 100644 --- a/examples/lammps/hippynn_lammps_example.ipynb +++ b/examples/lammps/hippynn_lammps_example.ipynb @@ -38,7 +38,7 @@ "\n", "try:\n", " with active_directory(\"./TEST_INP_MODEL\", create=False):\n", - " bundle = load_checkpoint_from_cwd(map_location='cpu',restore_db=False)\n", + " bundle = load_checkpoint_from_cwd(map_location='cpu',restart_db=False)\n", "except FileNotFoundError:\n", " raise FileNotFoundError(\"Model not found, run lammps_example.py first!\")\n", "\n", diff --git a/hippynn/__init__.py b/hippynn/__init__.py index bf59fa7d..3b451ac9 100644 --- a/hippynn/__init__.py +++ b/hippynn/__init__.py @@ -2,17 +2,21 @@ The hippynn python package. +.. autodata:: settings + :no-value: + + """ from . import _version __version__ = _version.get_versions()['version'] # Configuration settings -from ._settings_setup import settings +from ._settings_setup import settings, reload_settings # Pytorch modules from . import layers -from . import networks # wait this one is different from the other one. +from . import networks # Graph abstractions from . import graphs @@ -40,3 +44,12 @@ from . import tools from .tools import active_directory, log_terminal + +# The order is adjusted to put functions after objects in the documentation. +_dir = dir() +_lowerdir = [x for x in _dir if x[0].lower() == x[0]] +_upperdir = [x for x in _dir if x[0].upper() == x[0]] +__all__ = _lowerdir + _upperdir +del _dir, _lowerdir, _upperdir + +__all__ = [x for x in __all__ if not x.startswith("_")] diff --git a/hippynn/_settings_setup.py b/hippynn/_settings_setup.py index c72eb94f..675fae91 100644 --- a/hippynn/_settings_setup.py +++ b/hippynn/_settings_setup.py @@ -10,46 +10,79 @@ import warnings import os import configparser +from typing import Union from distutils.util import strtobool from types import SimpleNamespace from functools import partial +# Significant strings +SECTION_NAME = "GLOBALS" +SETTING_PREFIX = "HIPPYNN_" +LOCAL_RC_FILE_KEY = "LOCAL_RC_FILE" -try: - from tqdm.contrib import tqdm_auto +# Globals +DEFAULT_PROGRESS = None # this gets set to a tqdm object if possible +TQDM_PROGRESS = None # the best progress bar from tqdm, if available. - TQDM_PROGRESS = tqdm_auto -except ImportError: +def setup_tqdm(): + global TQDM_PROGRESS + global DEFAULT_PROGRESS try: - from tqdm import tqdm + from tqdm.contrib import tqdm_auto - TQDM_PROGRESS = tqdm + TQDM_PROGRESS = tqdm_auto except ImportError: - TQDM_PROGRESS = None - -if TQDM_PROGRESS is not None: - DEFAULT_PROGRESS = partial(TQDM_PROGRESS, mininterval=1.0, leave=False) -else: - DEFAULT_PROGRESS = None -### Progress handlers - -def progress_handler(prog_str): - if prog_str == "tqdm": - return DEFAULT_PROGRESS - elif prog_str.lower() == "none": - return None - else: try: - prog_float = float(prog_str) - return partial(TQDM_PROGRESS, mininterval=prog_float, leave=False) - except: - pass - warnings.warn(f"Unrecognized progress setting: '{prog_str}'. 
Setting to none.") + from tqdm import tqdm + + TQDM_PROGRESS = tqdm + except ImportError: + TQDM_PROGRESS = None + + if TQDM_PROGRESS is not None: + DEFAULT_PROGRESS = partial(TQDM_PROGRESS, mininterval=1.0, leave=False) + else: + DEFAULT_PROGRESS = None + +# Setting handlers: Take an input str or other value and return the appropriate value. + +def progress_handler(prog_setting: Union[str, float, bool, None]): + """ + Function for handling the progress bar settings. + + :param prog_setting: + :return: + """ + if TQDM_PROGRESS is None: + setup_tqdm() + + if prog_setting in (True, False, None): + prog_setting = { + True: "tqdm", + False: "none", + None: "none", + }[prog_setting] + + if isinstance(prog_setting, str): + prog_setting = prog_setting.lower() + if prog_setting == "tqdm": + return DEFAULT_PROGRESS + elif prog_setting.lower() == "none": + return None + + prog_setting = float(prog_setting) # Trigger error if not floatable. + + return partial(TQDM_PROGRESS, mininterval=prog_setting, leave=False) + def kernel_handler(kernel_string): + """ + :param kernel_string: + :return: + """ kernel_string = kernel_string.lower() kernel = { @@ -61,76 +94,107 @@ def kernel_handler(kernel_string): }.get(kernel_string, kernel_string) if kernel not in [True, False, "auto", "triton", "cupy", "numba"]: - warnings.warn(f"Unrecognized custom kernel option: {kernel_string}. Setting custom kernels to 'auto'") - kernel = "auto" + warnings.warn(f"Unexpected custom kernel setting: {kernel_string}.", stacklevel=3) return kernel -# keys: defaults, types, and handlers -default_settings = { - "PROGRESS": (DEFAULT_PROGRESS, progress_handler), +def bool_or_strtobool(key: Union[bool, str]): + if isinstance(key, bool): + return key + else: + return strtobool(key) + + +# keys: defaults, types, and handlers. +DEFAULT_SETTINGS = { + "PROGRESS": ('tqdm', progress_handler), "DEFAULT_PLOT_FILETYPE": (".pdf", str), - "TRANSPARENT_PLOT": (False, strtobool), - "DEBUG_LOSS_BROADCAST": (False, strtobool), - "DEBUG_GRAPH_EXECUTION": (False, strtobool), - "DEBUG_NODE_CREATION": (False, strtobool), - "DEBUG_AUTOINDEXING": (False, strtobool), + "TRANSPARENT_PLOT": (False, bool_or_strtobool), + "DEBUG_LOSS_BROADCAST": (False, bool_or_strtobool), + "DEBUG_GRAPH_EXECUTION": (False, bool_or_strtobool), + "DEBUG_NODE_CREATION": (False, bool_or_strtobool), + "DEBUG_AUTOINDEXING": (False, bool_or_strtobool), "USE_CUSTOM_KERNELS": ("auto", kernel_handler), - "WARN_LOW_DISTANCES": (True, strtobool), - "TIMEPLOT_AUTOSCALING": (True, strtobool), + "WARN_LOW_DISTANCES": (True, bool_or_strtobool), + "TIMEPLOT_AUTOSCALING": (True, bool_or_strtobool), "PYTORCH_GPU_MEM_FRAC": (1.0, float), } -settings = SimpleNamespace(**{k: default for k, (default, handler) in default_settings.items()}) +INITIAL_SETTINGS = {k: handler(default) for k, (default, handler) in DEFAULT_SETTINGS.items()} + +settings = SimpleNamespace(**INITIAL_SETTINGS) settings.__doc__ = """ Values for the current hippynn settings. See :doc:`/user_guide/settings` for a description. """ -config_sources = {} # Dictionary of configuration variable sources mapping to dictionary of configuration. -# We add to this dictionary in order of application -SECTION_NAME = "GLOBALS" +def reload_settings(**kwargs): + """ + Attempt to reload the hippynn library settings. 
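+
+    A small usage sketch (``PROGRESS`` is one of the documented setting names):
+
+    >>> import hippynn
+    >>> hippynn.reload_settings(PROGRESS=0.1)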
-rc_name = os.path.expanduser("~/.hippynnrc") -if os.path.exists(rc_name) and os.path.isfile(rc_name): - config = configparser.ConfigParser(inline_comment_prefixes="#") - config.read(rc_name) - if SECTION_NAME not in config: - warnings.warn(f"Config file {rc_name} does not contain a {SECTION_NAME} section and will be ignored!") - else: - config_sources["~/.hippynnrc"] = config[SECTION_NAME] + Settings sources are, in order from least to greatest priority: + - Default values + - The file `~/.hippynnrc`, which is a standard python config file which contains + variables under the section name [GLOBALS]. + - A file specified by the environment variable `HIPPYNN_LOCAL_RC_FILE` + which is treated the same as the user rc file. + - Environment variables prefixed by ``HIPPYNN_``, e.g. ``HIPPYNN_DEFAULT_PLOT_FILETYPE``. + - Keyword arguments passed to this function. -SETTING_PREFIX = "HIPPYNN_" -hippynn_environment_variables = { - k.replace(SETTING_PREFIX, ""): v for k, v in os.environ.items() if k.startswith(SETTING_PREFIX) -} -LOCAL_RC_FILE_KEY = "LOCAL_RC_FILE" + :param kwargs: explicit settings to change. + + :return: + """ + # Developer note: this function modifies the module-scope `settings` directly. -if LOCAL_RC_FILE_KEY in hippynn_environment_variables: - local_rc_fname = hippynn_environment_variables.pop(LOCAL_RC_FILE_KEY) - if os.path.exists(local_rc_fname) and os.path.isfile(local_rc_fname): - local_config = configparser.ConfigParser() - local_config.read(local_rc_fname) - if SECTION_NAME not in local_config: - warnings.warn(f"Config file {local_rc_fname} does not contain a {SECTION_NAME} section and will be ignored!") + config_sources = {} # Dictionary of configuration variable sources mapping to dictionary of configuration. + # We add to this dictionary in order of application + + rc_name = os.path.expanduser("~/.hippynnrc") + if os.path.exists(rc_name) and os.path.isfile(rc_name): + config = configparser.ConfigParser(inline_comment_prefixes="#") + config.read(rc_name) + if SECTION_NAME not in config: + warnings.warn(f"Config file {rc_name} does not contain a {SECTION_NAME} section and will be ignored!") else: - config_sources[LOCAL_RC_FILE_KEY] = local_config[SECTION_NAME] - else: - warnings.warn(f"Local configuration file {local_rc_fname} not found.") - -config_sources["environment variables"] = hippynn_environment_variables - -for sname, source in config_sources.items(): - for key, value in source.items(): - key = key.upper() - if key in default_settings: - default, handler = default_settings[key] - try: - setattr(settings, key, handler(value)) - except Exception as ee: - raise ValueError(f"Value {value} for setting {key} is invalid") from ee + config_sources["~/.hippynnrc"] = config[SECTION_NAME] + + hippynn_environment_variables = { + k.replace(SETTING_PREFIX, ""): v for k, v in os.environ.items() if k.startswith(SETTING_PREFIX) + } + + if LOCAL_RC_FILE_KEY in hippynn_environment_variables: + local_rc_fname = hippynn_environment_variables.pop(LOCAL_RC_FILE_KEY) + if os.path.exists(local_rc_fname) and os.path.isfile(local_rc_fname): + local_config = configparser.ConfigParser() + local_config.read(local_rc_fname) + if SECTION_NAME not in local_config: + warnings.warn(f"Config file {local_rc_fname} does not contain a {SECTION_NAME} section and will be ignored!") + else: + config_sources[LOCAL_RC_FILE_KEY] = local_config[SECTION_NAME] else: - warnings.warn(f"Configuration source {sname} contains invalid variables ({key}). 
They will not be used.") + warnings.warn(f"Local configuration file {local_rc_fname} not found.") + + config_sources["environment variables"] = hippynn_environment_variables + config_sources["kwargs"] = kwargs.copy() + + for sname, source in config_sources.items(): + for key, value in source.items(): + key = key.upper() + if key in DEFAULT_SETTINGS: + default, handler = DEFAULT_SETTINGS[key] + try: + setattr(settings, key, handler(value)) + except Exception as ee: + raise ValueError(f"Value {value} for setting {key} is invalid") from ee + else: + warnings.warn(f"Configuration source {sname} contains invalid variables ({key}). These will be ignored.") + + return settings + + +reload_settings() + diff --git a/hippynn/custom_kernels/__init__.py b/hippynn/custom_kernels/__init__.py index 8ef8953b..ab9bdcce 100644 --- a/hippynn/custom_kernels/__init__.py +++ b/hippynn/custom_kernels/__init__.py @@ -1,64 +1,90 @@ """ Custom Kernels for hip-nn interaction sum. -This module provides implementations in pytorch, numba, and cupy. +This module provides implementations in pytorch, numba, cupy, and triton. Pytorch implementations take extra memory, but launch faster than numba kernels. - Numba kernels use far less memory, but do come with some launching overhead on GPUs. - Cupy kernels only work on the GPU, but are faster than numba. Cupy kernels require numba for CPU operations. +Triton custom kernels only work on the GPU, and are generaly faster than CUPY. +Triton kernels revert to numba or pytorch as available on module import. + +On import, this module attempts to set the custom kernels as specified by the +user in hippynn.settings. + +.. py:data:: CUSTOM_KERNELS_AVAILABLE + :type: list[str] + + The set of custom kernels available, based on currently installed packages and hardware. + +.. py:data:: CUSTOM_KERNELS_ACTIVE + :type: str + + The currently active implementation of custom kernels. + """ import warnings from typing import Union - +import torch from .. import settings from . import autograd_wrapper, env_pytorch -CUSTOM_KERNELS_AVAILABLE = [] -try: - import numba - CUSTOM_KERNELS_AVAILABLE.append("numba") -except ImportError: +class CustomKernelError(Exception): pass -try: - import cupy - if "numba" not in CUSTOM_KERNELS_AVAILABLE: - warnings.warn("Cupy was found, but numba was not. Cupy custom kernels not available.") - else: - CUSTOM_KERNELS_AVAILABLE.append("cupy") -except ImportError: - pass +def populate_custom_kernel_availability(): + """ + Check available imports and populate the list of available custom kernels. -try: - import triton - import torch - device_capability = torch.cuda.get_device_capability() - if device_capability[0] > 6: - CUSTOM_KERNELS_AVAILABLE.append("triton") - else: - warnings.warn( - f"Triton found but not supported by GPU's compute capability: {device_capability}" - ) -except ImportError: - pass + This function changes the global variable custom_kernels.CUSTOM_KERNELS_AVAILABLE - -except ImportError: - pass + :return: + """ -if not CUSTOM_KERNELS_AVAILABLE: - warnings.warn( - "Triton, cupy and numba are not available: Custom kernels will be disabled and performance maybe be degraded.") + # check order for kernels is numba, cupy, triton. 
+ global CUSTOM_KERNELS_AVAILABLE -CUSTOM_KERNELS_ACTIVE = False + CUSTOM_KERNELS_AVAILABLE = [] -envsum, sensesum, featsum = None, None, None + try: + import numba + + CUSTOM_KERNELS_AVAILABLE.append("numba") + except ImportError: + pass + + if torch.cuda.is_available(): + try: + import cupy + + if "numba" not in CUSTOM_KERNELS_AVAILABLE: + warnings.warn("Cupy was found, but numba was not. Cupy custom kernels not available.") + else: + CUSTOM_KERNELS_AVAILABLE.append("cupy") + except ImportError: + pass + try: + import triton + if torch.cuda.is_available(): + device_capability = torch.cuda.get_device_capability() + if device_capability[0] > 6: + CUSTOM_KERNELS_AVAILABLE.append("triton") + else: + warnings.warn( + f"Triton found but not supported by GPU's compute capability: {device_capability}" + ) + except ImportError: + pass + + + if not CUSTOM_KERNELS_AVAILABLE: + warnings.warn( + "Triton, cupy and numba are not available: Custom kernels will be disabled and performance maybe be degraded.") + return CUSTOM_KERNELS_AVAILABLE def _check_numba(): import numba.cuda @@ -86,12 +112,18 @@ def _check_cupy(): if not cupy.cuda.is_available(): if torch.cuda.is_available(): warnings.warn("cupy.cuda.is_available() returned False: Custom kernels will fail on GPU tensors.") - + return def set_custom_kernels(active: Union[bool, str] = True): """ Activate or deactivate custom kernels for interaction. + This function changes the global variables: + - custom_kernels.envsum + - custom_kernels.sensum + - custom_kernels.featsum + - custom_kernels.CUSTOM_KERNELS_ACTIVE + :param active: If true, set custom kernels to the best available. If False, turn them off and default to pytorch. If "triton", "numba" or "cupy", use those implementations explicitly. If "auto", use best available. :return: None @@ -101,17 +133,20 @@ def set_custom_kernels(active: Union[bool, str] = True): if isinstance(active, str): active = active.lower() - if active not in [True, False, "triton", "numba", "cupy", "pytorch", "auto"]: - raise ValueError(f"Unrecognized custom kernel implementation: {active}") + if active not in _POSSIBLE_CUSTOM_KERNELS: + raise CustomKernelError(f"Unrecognized custom kernel implementation: {active}") - active_map = {"auto": True, "pytorch": False} if not CUSTOM_KERNELS_AVAILABLE: - if active == "auto" or active == "pytorch": + if active in ("auto", "pytorch"): # These are equivalent to "false" when custom kernels are not available. active = False elif active: - raise RuntimeError( - "Triton, numba and cupy were not found. Custom kernels are not available, but they were required by library settings.") + # The user explicitly set a custom kernel implementation or just True. + raise CustomKernelError( + "Triton, numba and cupy were not found." + + f"Custom kernels are not available, but they were required by library setting: {active}") else: + # If custom kernels are available, then "auto" and "pytorch" revert to bool values. + active_map = {"auto": True, "pytorch": False} active = active_map.get(active, active) # Handle fallback to pytorch kernels. @@ -124,7 +159,7 @@ def set_custom_kernels(active: Union[bool, str] = True): # Select custom kernel implementation if not CUSTOM_KERNELS_AVAILABLE: - raise RuntimeError("Numba was not found. Custom kernels are not available.") + raise CustomKernelError("Numba was not found. 
Custom kernels are not available.") if active is True: if "triton" in CUSTOM_KERNELS_AVAILABLE: @@ -135,18 +170,20 @@ def set_custom_kernels(active: Union[bool, str] = True): active = "numba" if active not in CUSTOM_KERNELS_AVAILABLE: - raise RuntimeError(f"Unavailable custom kernel implementation: {active}") + raise CustomKernelError(f"Unavailable custom kernel implementation: {active}") if active == "triton": from .env_triton import envsum as triton_envsum, sensesum as triton_sensesum, featsum as triton_featsum envsum, sensesum, featsum = autograd_wrapper.wrap_envops(triton_envsum, triton_sensesum, triton_featsum) + elif active == "cupy": _check_numba() _check_cupy() from .env_cupy import cupy_envsum, cupy_featsum, cupy_sensesum envsum, sensesum, featsum = autograd_wrapper.wrap_envops(cupy_envsum, cupy_sensesum, cupy_featsum) + elif active == "numba": _check_numba() from .env_numba import new_envsum, new_featsum, new_sensesum @@ -157,11 +194,33 @@ def set_custom_kernels(active: Union[bool, str] = True): # We shouldn't get here except possibly mid-development, but just in case: # if you add a custom kernel implementation remember to add to this # dispatch block. - raise ValueError(f"Unknown Implementation: {active}") + raise CustomKernelError(f"Unknown Implementation: {active}") CUSTOM_KERNELS_ACTIVE = active + return +CUSTOM_KERNELS_AVAILABLE = [] + +_POSSIBLE_CUSTOM_KERNELS = [True, False, "triton", "numba", "cupy", "pytorch", "auto"] try_custom_kernels = settings.USE_CUSTOM_KERNELS -set_custom_kernels(try_custom_kernels) + +CUSTOM_KERNELS_ACTIVE = None + +envsum, sensesum, featsum = None, None, None + +try: + populate_custom_kernel_availability() + set_custom_kernels(try_custom_kernels) +except CustomKernelError as eee: + raise +except Exception as ee: + warnings.warn(f"Custom kernels are disabled due to an expected error:\n\t{ee}", stacklevel=2) + del ee + + envsum = env_pytorch.envsum + sensesum = env_pytorch.sensesum + featsum = env_pytorch.featsum + CUSTOM_KERNELS_ACTIVE = False + del try_custom_kernels diff --git a/hippynn/databases/database.py b/hippynn/databases/database.py index 15bdd93b..a21301a6 100644 --- a/hippynn/databases/database.py +++ b/hippynn/databases/database.py @@ -347,7 +347,7 @@ def make_automatic_splits(self, split_prefix=None, dry_run=False): it fails pretty strictly. :param split_prefix: None, use default. - If otherwise, use this prefix to determine what arrays are masks. + If otherwise, use this prefix to determine what arrays are masks. :param dry_run: Only validate that existing split masks are correct; don't perform splitting. :return: """ diff --git a/hippynn/experiment/device.py b/hippynn/experiment/device.py index b648ba65..9f144dcc 100644 --- a/hippynn/experiment/device.py +++ b/hippynn/experiment/device.py @@ -22,9 +22,7 @@ def set_devices( Evaluation loss is performed on CPU. :param model: current model on CPU - :type model: GraphModule :param loss: current loss module on CPU - :type loss: GraphModule :param evaluator: evaluator :type evaluator: Evaluator :param optimizer: optimizer with state dictionary on CPU diff --git a/hippynn/experiment/serialization.py b/hippynn/experiment/serialization.py index 57f9574c..90ddf1aa 100644 --- a/hippynn/experiment/serialization.py +++ b/hippynn/experiment/serialization.py @@ -193,7 +193,6 @@ def load_model_from_cwd(map_location=None, model_device=None, **kwargs) -> Graph :param model_device: automatically handle device mapping. 
Defaults to None, defaults to None :type model_device: Union[int, str, torch.device], optional :return: model with reloaded parameters - :rtype: GraphModule """ mapped, model_device = check_mapping_devices(map_location, model_device) kwargs["map_location"] = mapped diff --git a/hippynn/layers/pairs/dispatch.py b/hippynn/layers/pairs/dispatch.py index b14a8b92..853efa05 100644 --- a/hippynn/layers/pairs/dispatch.py +++ b/hippynn/layers/pairs/dispatch.py @@ -4,7 +4,6 @@ from itertools import product import numpy as np -from scipy.spatial import KDTree import torch from .open import PairMemory @@ -137,11 +136,13 @@ def neighbor_list_torch(cutoff: float, coords, cell): return pf, ps, pi def neighbor_list_kdtree(cutoff, coords, cell): - ''' + """ Use KD Tree implementation from scipy.spatial to find pairs under periodic boundary conditions with an orthorhombic cell. - ''' - + """ + # Dev note: Imports are cached, this will only be slow once. + from scipy.spatial import KDTree + # Verify that cell is orthorhombic cell_prod = cell @ cell.T if torch.count_nonzero(cell_prod - torch.diag(torch.diag(cell_prod))): diff --git a/hippynn/molecular_dynamics/__init__.py b/hippynn/molecular_dynamics/__init__.py index 622cb6ce..0807e466 100644 --- a/hippynn/molecular_dynamics/__init__.py +++ b/hippynn/molecular_dynamics/__init__.py @@ -1,7 +1,10 @@ """ -Molecular dynamics driver with great flexibility and customizability regarding which quantities which are evolved -and what algorithms are used to evolve them. Calls a hippynn `Predictor` on current state during each MD step. + +Molecular dynamics driver with great flexibility and customizability regarding which quantities which are evolved +and what algorithms are used to evolve them. Calls a hippynn `Predictor` on current state during each MD step. 
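+
+A hypothetical driver sketch (``position_variable`` and ``predictor`` are placeholder objects
+assumed to have been constructed elsewhere, and the timestep is illustrative):
+
+>>> md = MolecularDynamics(variables=[position_variable], model=predictor)
+>>> md.run(dt=0.5, n_steps=100, record_every=10)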
+ """ + from .md import MolecularDynamics, Variable, NullUpdater, VelocityVerlet, LangevinDynamics diff --git a/hippynn/molecular_dynamics/md.py b/hippynn/molecular_dynamics/md.py index 1375fd77..58bca979 100644 --- a/hippynn/molecular_dynamics/md.py +++ b/hippynn/molecular_dynamics/md.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Optional from functools import singledispatchmethod from copy import copy @@ -25,26 +26,20 @@ def __init__( data: dict[str, torch.Tensor], model_input_map: dict[str, str] = dict(), updater: VariableUpdater = None, - device: torch.device = None, - dtype: torch.dtype = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ) -> None: """ :param name: name for variable - :type name: str :param data: dictionary of tracked data in the form `value_name: value` - :type data: dict[str, torch.Tensor] :param model_input_map: dictionary of correspondences between data tracked by Variable and inputs to the HIP-NN model in the form `hipnn-db_name: variable-data-key`, defaults to dict() - :type model_input_map: dict[str, str], optional :param updater: object which will update the data of the Variable over the course of the MD simulation, defaults to None - :type updater: VariableUpdater, optional :param device: device on which to keep data, defaults to None - :type device: torch.device, optional :param dtype: dtype for float type data, defaults to None - :type dtype: torch.dtype, optional - """ + """ self.name = name self.data = data self.model_input_map = model_input_map @@ -164,22 +159,19 @@ def variable(self, variable): ) self._variable = variable - def pre_step(self, dt): + def pre_step(self, dt: float): """Updates to variables performed during each step of MD simulation before HIPNN model evaluation :param dt: timestep - :type dt: float """ pass - def post_step(self, dt, model_outputs): + def post_step(self, dt: float, model_outputs: dict): """Updates to variables performed during each step of MD simulation after HIPNN model evaluation :param dt: timestep - :type dt: float :param model_outputs: dictionary of HIPNN model outputs - :type model_outputs: dict - """ + """ pass @@ -210,25 +202,21 @@ def __init__( """ :param force_db_name: key which will correspond to the force on the corresponding Variable in the HIPNN model output dictionary - :type force_db_name: str :param units_force: amount of eV equal to one in the units used for force output of HIPNN model (eg. 
if force output in kcal, units_force = ase.units.kcal = 2.6114e22 since 2.6114e22 kcal = 1 eV), by default ase.units.eV = 1, defaults to ase.units.eV - :type units_force: float, optional :param units_acc: amount of Ang/fs^2 equal to one in the units used for acceleration in the corresponding Variable, by default units.Ang/(1.0 ** 2) = 1, defaults to ase.units.Ang/(1.0**2) - :type units_acc: float, optional - """ + """ self.force_key = force_db_name self.force_factor = units_force / units_acc - def pre_step(self, dt): + def pre_step(self, dt: float): """Updates to variables performed during each step of MD simulation before HIPNN model evaluation :param dt: timestep - :type dt: float - """ + """ self.variable.data["velocity"] = self.variable.data["velocity"] + 0.5 * dt * self.variable.data["acceleration"] self.variable.data["position"] = self.variable.data["position"] + self.variable.data["velocity"] * dt try: @@ -236,14 +224,12 @@ def pre_step(self, dt): except KeyError: pass - def post_step(self, dt, model_outputs): + def post_step(self, dt: float, model_outputs: dict): """Updates to variables performed during each step of MD simulation after HIPNN model evaluation :param dt: timestep - :type dt: float :param model_outputs: dictionary of HIPNN model outputs - :type model_outputs: dict - """ + """ self.variable.data["force"] = model_outputs[self.force_key].to(self.variable.device) if len(self.variable.data["force"].shape) == len(self.variable.data["mass"].shape): self.variable.data["acceleration"] = self.variable.data["force"].detach() / self.variable.data["mass"] * self.force_factor @@ -266,29 +252,23 @@ def __init__( force_db_name: str, temperature: float, frix: float, - units_force=ase.units.eV, - units_acc=ase.units.Ang / (1.0**2), - seed: int = None, + units_force: float = ase.units.eV, + units_acc: float = ase.units.Ang / (1.0**2), + seed: Optional[int] = None, ): """ :param force_db_name: key which will correspond to the force on the corresponding Variable in the HIPNN model output dictionary - :type force_db_name: str :param temperature: temperature for Langevin algorithm - :type temperature: float :param frix: friction coefficient for Langevin algorithm - :type frix: float :param units_force: amount of eV equal to one in the units used for force output of HIPNN model (eg. 
if force output in kcal, units_force = ase.units.kcal = 2.6114e22 since 2.6114e22 kcal = 1 eV), by default ase.units.eV = 1, defaults to ase.units.eV - :type units_force: float, optional :param units_acc: amount of Ang/fs^2 equal to one in the units used for acceleration in the corresponding Variable, by default units.Ang/(1.0 ** 2) = 1, defaults to ase.units.Ang/(1.0**2) - :type units_acc: float, optional :param seed: used to set seed for reproducibility, defaults to None - :type seed: int, optional - """ + """ self.force_key = force_db_name self.force_factor = units_force / units_acc @@ -299,12 +279,11 @@ def __init__( if seed is not None: torch.manual_seed(seed) - def pre_step(self, dt): + def pre_step(self, dt:float): """Updates to variables performed during each step of MD simulation before HIPNN model evaluation :param dt: timestep - :type dt: float - """ + """ self.variable.data["position"] = self.variable.data["position"] + self.variable.data["velocity"] * dt @@ -314,14 +293,13 @@ def pre_step(self, dt): self.variable.data["unwrapped_position"] = self.variable.data["unwrapped_position"] + self.variable.data["velocity"] * dt except KeyError: self.variable.data["unwrapped_position"] = copy(self.variable.data["position"]) - def post_step(self, dt, model_outputs): - """Updates to variables performed during each step of MD simulation after HIPNN model evaluation + def post_step(self, dt: float, model_outputs: dict): + """ + Updates to variables performed during each step of MD simulation after HIPNN model evaluation :param dt: timestep - :type dt: float :param model_outputs: dictionary of HIPNN model outputs - :type model_outputs: dict - """ + """ self.variable.data["force"] = model_outputs[self.force_key].to(self.variable.device) @@ -348,19 +326,15 @@ def __init__( self, variables: list[Variable], model: Predictor, - device: torch.device = None, - dtype: torch.dtype = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): """ :param variables: list of Variable objects which will be tracked during simulation - :type variables: list[Variable] :param model: HIPNN Predictor - :type model: Predictor :param device: device to move variables and model to, defaults to None - :type device: torch.device, optional :param dtype: dtype to convert all float type variable data and model parameters to, defaults to None - :type dtype: torch.dtype, optional - """ + """ self.variables = variables self.model = model @@ -484,19 +458,15 @@ def _update_data(self, model_outputs: dict): self._data[f"output_{key}"].append(value.cpu().detach()[0]) except KeyError: self._data[f"output_{key}"] = [value.cpu().detach()[0]] - - def run(self, dt: float, n_steps: int, record_every: int = None): + def run(self, dt: float, n_steps: int, record_every: Optional[int] = None): """Run `n_steps` of MD algorithm. 
:param dt: timestep - :type dt: float :param n_steps: number of steps to execute - :type n_steps: int :param record_every: frequency at which to store the data at a step in memory, record_every = 1 means every step will be stored, defaults to None - :type record_every: int, optional - """ + """ for i in progress_bar(range(n_steps)): model_outputs = self._step(dt) diff --git a/pyproject.toml b/pyproject.toml index c95a1da9..b2173b6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ dependencies=[ docs=[ "sphinx", "sphinx_rtd_theme", - "ase", ] full=[ "ase", @@ -37,4 +36,5 @@ full=[ "graphviz", "h5py", "lightning", + "scipy", ] \ No newline at end of file diff --git a/tests/progress_settings.py b/tests/progress_settings.py new file mode 100644 index 00000000..71bb8686 --- /dev/null +++ b/tests/progress_settings.py @@ -0,0 +1,23 @@ + +import hippynn + + +def trigger_progress(): + print(hippynn.settings.PROGRESS) + for _ in hippynn.tools.progress_bar(range(30_000_000)): + _ = _-1 + +hippynn.reload_settings(PROGRESS=None) +trigger_progress() + +hippynn.reload_settings(PROGRESS=True) +trigger_progress() + +hippynn.reload_settings(PROGRESS=False) +trigger_progress() + +hippynn.reload_settings(PROGRESS="tqdm") +trigger_progress() + +hippynn.reload_settings(PROGRESS=0.01) +trigger_progress()
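+
+# A hypothetical equivalent using environment variables instead of reload_settings
+# (settings prefixed with HIPPYNN_ are read when settings are loaded; see the settings documentation):
+#   HIPPYNN_PROGRESS=0.01 python progress_settings.py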