Skip to content

Commit

Permalink
Add domains, versioning, and tests (#54)
Browse files Browse the repository at this point in the history
Co-authored-by: Lily Wang <[email protected]>
  • Loading branch information
lilyminium and Lily Wang committed Aug 25, 2023
1 parent 43f3d8a commit db76183
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 12 deletions.
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies:

# gcn
- dgl >=1.0
- pytorch
- pytorch >=2.0
- pytorch-lightning

# parallelism
Expand Down
1 change: 1 addition & 0 deletions openff/nagl/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class ReadoutModule(ImmutableModel):


class ModelConfig(ImmutableModel, FromYamlMixin):
version: typing.Literal["0.1"]
atom_features: typing.List[DiscriminatedAtomFeatureType] = Field(
description="Atom features to use"
)
Expand Down
78 changes: 78 additions & 0 deletions openff/nagl/domains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import typing

from openff.nagl._base.base import ImmutableModel

try:
from pydantic.v1 import Field
except ImportError:
from pydantic import Field

if typing.TYPE_CHECKING:
from openff.toolkit.topology import Molecule

class ChemicalDomain(ImmutableModel):
    """A domain of chemical space to which a molecule can belong.

    Used for determining if a molecule is represented in the
    training data for a given model.

    Every ``check_*`` method shares the same return contract: a plain
    ``bool`` when ``return_error_message`` is ``False``, and a
    ``(bool, str)`` tuple when it is ``True``. The string is empty
    whenever the check passes.
    """

    # Atomic numbers permitted in the domain. An empty tuple disables the
    # element check entirely (every element is accepted).
    allowed_elements: typing.Tuple[int, ...] = Field(
        description="The atomic numbers of the elements allowed in the domain",
        default_factory=tuple
    )
    # SMARTS patterns that must not match anywhere in the molecule.
    forbidden_patterns: typing.Tuple[str, ...] = Field(
        description="The SMARTS patterns which are forbidden in the domain",
        default_factory=tuple
    )

    def check_molecule(
        self,
        molecule: "Molecule",
        return_error_message: bool = False
    ) -> typing.Union[bool, typing.Tuple[bool, str]]:
        """Run all domain checks on a molecule, stopping at the first failure.

        Parameters
        ----------
        molecule: Molecule
            The molecule to check.
        return_error_message: bool
            If ``True``, return ``(is_allowed, error_message)`` instead
            of just ``is_allowed``.

        Returns
        -------
        bool or Tuple[bool, str]
            Whether the molecule passed every check, optionally paired
            with the message of the first failed check (``""`` on success).
        """
        checks = (
            self.check_allowed_elements,
            self.check_forbidden_patterns,
        )
        for check in checks:
            # Always request the message so the result can be unpacked
            # uniformly; it is dropped below if the caller did not ask for it.
            is_allowed, err = check(molecule, return_error_message=True)
            if not is_allowed:
                return (False, err) if return_error_message else False
        return (True, "") if return_error_message else True

    def check_allowed_elements(
        self,
        molecule: "Molecule",
        return_error_message: bool = False
    ) -> typing.Union[bool, typing.Tuple[bool, str]]:
        """Check that every atom in the molecule has an allowed element.

        If ``allowed_elements`` is empty, no restriction is applied and
        the check passes.

        Parameters
        ----------
        molecule: Molecule
            The molecule to check.
        return_error_message: bool
            If ``True``, return ``(is_allowed, error_message)`` instead
            of just ``is_allowed``.

        Returns
        -------
        bool or Tuple[bool, str]
            Whether all elements are allowed, optionally paired with a
            message naming the first disallowed atomic number found.
        """
        if not self.allowed_elements:
            # No restriction configured. This exit must still honour
            # ``return_error_message``: ``check_molecule`` unpacks the
            # result as a 2-tuple, and a bare ``True`` here would raise
            # ``TypeError`` for any domain without an element list.
            return (True, "") if return_error_message else True

        allowed = set(self.allowed_elements)
        for atom in molecule.atoms:
            atomic_number = atom.atomic_number
            if atomic_number not in allowed:
                if return_error_message:
                    err = f"Molecule contains forbidden element {atomic_number}"
                    return False, err
                return False
        return (True, "") if return_error_message else True

    def check_forbidden_patterns(
        self,
        molecule: "Molecule",
        return_error_message: bool = False
    ) -> typing.Union[bool, typing.Tuple[bool, str]]:
        """Check that the molecule matches none of the forbidden SMARTS.

        Parameters
        ----------
        molecule: Molecule
            The molecule to check.
        return_error_message: bool
            If ``True``, return ``(is_allowed, error_message)`` instead
            of just ``is_allowed``.

        Returns
        -------
        bool or Tuple[bool, str]
            Whether no forbidden pattern matched, optionally paired with
            a message naming the first matching pattern.
        """
        for pattern in self.forbidden_patterns:
            if molecule.chemical_environment_matches(pattern):
                if return_error_message:
                    err = f"Molecule contains forbidden SMARTS pattern {pattern}"
                    return False, err
                return False
        return (True, "") if return_error_message else True
61 changes: 53 additions & 8 deletions openff/nagl/nn/_models.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import copy
from typing import TYPE_CHECKING, Tuple, Dict, Union, Callable, Literal, Optional
import warnings

import torch
import pytorch_lightning as pl

from openff.utilities.exceptions import MissingOptionalDependencyError
from openff.nagl.nn._containers import ConvolutionModule, ReadoutModule
from openff.nagl.config.model import ModelConfig

from openff.nagl.domains import ChemicalDomain

if TYPE_CHECKING:
from openff.toolkit.topology import Molecule
from openff.nagl.features.atoms import AtomFeature
from openff.nagl.features.bonds import BondFeature
from openff.nagl.molecule._dgl import DGLMoleculeOrBatch
from openff.nagl.nn.postprocess import PostprocessLayer
from openff.nagl.nn.activation import ActivationFunction
from openff.nagl.nn.gcn._base import BaseGCNStack


class BaseGNNModel(pl.LightningModule):
Expand All @@ -41,10 +37,22 @@ def forward(
return readouts

class GNNModel(BaseGNNModel):
def __init__(self, config: ModelConfig):
def __init__(
self,
config: ModelConfig,
chemical_domain: Optional[ChemicalDomain] = None,
):
if not isinstance(config, ModelConfig):
config = ModelConfig(**config)

if chemical_domain is None:
chemical_domain = ChemicalDomain(
allowed_elements=tuple(),
forbidden_patterns=tuple(),
)
elif not isinstance(chemical_domain, ChemicalDomain):
chemical_domain = ChemicalDomain(**chemical_domain)

convolution_module = ConvolutionModule.from_config(
config.convolution,
n_input_features=config.n_atom_features,
Expand All @@ -62,8 +70,13 @@ def __init__(self, config: ModelConfig):
readout_modules=readout_modules,
)

self.save_hyperparameters({"config": config.dict()})
self.save_hyperparameters({
"config": config.dict(),
"chemical_domain": chemical_domain.dict(),
})
self.config = config
self.chemical_domain = chemical_domain


@classmethod
def from_yaml(cls, filename):
Expand All @@ -84,6 +97,8 @@ def compute_properties(
self,
molecule: "Molecule",
as_numpy: bool = True,
check_domains: bool = False,
error_if_unsupported: bool = True,
) -> Dict[str, torch.Tensor]:
"""
Compute the trained property for a molecule.
Expand All @@ -95,11 +110,29 @@ def compute_properties(
as_numpy: bool
Whether to return the result as a numpy array.
If ``False``, the result will be a ``torch.Tensor``.
check_domains: bool
Whether to check if the molecule is similar
to the training data.
error_if_unsupported: bool
Whether to raise an error if the molecule
is not represented in the training data.
This is only used if ``check_domains`` is ``True``.
If ``False``, a warning will be raised instead.
Returns
-------
result: Dict[str, torch.Tensor] or Dict[str, numpy.ndarray]
"""
if check_domains:
is_supported, error = self.chemical_domain.check_molecule(
molecule, return_error_message=True
)
if not is_supported:
if error_if_unsupported:
raise ValueError(error)
else:
warnings.warn(error)

try:
values = self._compute_properties_dgl(molecule)
except (MissingOptionalDependencyError, TypeError):
Expand All @@ -113,6 +146,8 @@ def compute_property(
molecule: "Molecule",
readout_name: Optional[str] = None,
as_numpy: bool = True,
check_domains: bool = False,
error_if_unsupported: bool = True,
):
"""
Compute the trained property for a molecule.
Expand All @@ -128,6 +163,14 @@ def compute_property(
as_numpy: bool
Whether to return the result as a numpy array.
If ``False``, the result will be a ``torch.Tensor``.
check_domains: bool
Whether to check if the molecule is similar
to the training data.
error_if_unsupported: bool
Whether to raise an error if the molecule
is not represented in the training data.
This is only used if ``check_domains`` is ``True``.
If ``False``, a warning will be raised instead.
Returns
-------
Expand All @@ -136,6 +179,8 @@ def compute_property(
properties = self.compute_properties(
molecule=molecule,
as_numpy=as_numpy,
check_domains=check_domains,
error_if_unsupported=error_if_unsupported,
)
if readout_name is None:
if len(properties) == 1:
Expand Down
Binary file modified openff/nagl/tests/data/example_am1bcc_model.pt
Binary file not shown.
1 change: 1 addition & 0 deletions openff/nagl/tests/data/example_model_config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
version: "0.1"
atom_features:
- categories:
- C
Expand Down
1 change: 1 addition & 0 deletions openff/nagl/tests/data/example_training_config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
model:
version: "0.1"
atom_features:
- name: atomic_element
categories: ["C", "O", "H", "N", "S", "F", "Br", "Cl", "I", "P"]
Expand Down
1 change: 1 addition & 0 deletions openff/nagl/tests/data/example_training_config_lazy.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
model:
version: "0.1"
atom_features:
- name: atomic_element
categories: ["C", "O", "H", "N", "S", "F", "Br", "Cl", "I", "P"]
Expand Down
32 changes: 29 additions & 3 deletions openff/nagl/tests/nn/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
from openff.nagl.nn._pooling import PoolAtomFeatures, PoolBondFeatures
from openff.nagl.nn.postprocess import ComputePartialCharges
from openff.nagl.nn._sequential import SequentialLayers
from openff.nagl.domains import ChemicalDomain
from openff.nagl.tests.data.files import (
# EXAMPLE_AM1BCC_MODEL_STATE_DICT,
# MODEL_CONFIG_V7,
EXAMPLE_AM1BCC_MODEL,
)
from openff.nagl.features.atoms import (
Expand Down Expand Up @@ -107,6 +106,9 @@ class TestGNNModel:
@pytest.fixture()
def am1bcc_model(self):
model = GNNModel.load(EXAMPLE_AM1BCC_MODEL, eval_mode=True)
model.chemical_domain = ChemicalDomain(
allowed_elements=(1, 6)
)
# model = GNNModel.from_yaml(MODEL_CONFIG_V7)
# model.load_state_dict(torch.load(EXAMPLE_AM1BCC_MODEL_STATE_DICT))
# model.eval()
Expand Down Expand Up @@ -139,7 +141,7 @@ def test_init(self):
]

model = GNNModel(
{
{ "version": "0.1",
"atom_features": atom_features,
"bond_features": bond_features,
"convolution": {
Expand Down Expand Up @@ -214,6 +216,30 @@ def test_compute_properties(self, am1bcc_model, openff_methane_uncharged, expect
assert len(charges) == 1
assert_allclose(charges["am1bcc_charges"], expected_methane_charges, atol=1e-5)

def test_compute_properties_check_domains(self, am1bcc_model, openff_methane_uncharged):
am1bcc_model.compute_properties(
openff_methane_uncharged,
check_domains=True,
error_if_unsupported=True,
)

def test_compute_properties_warning_domains(self, am1bcc_model, openff_methyl_methanoate):
with pytest.warns(UserWarning):
am1bcc_model.compute_properties(
openff_methyl_methanoate,
check_domains=True,
error_if_unsupported=False,
)

def test_compute_properties_error_domains(self, am1bcc_model, openff_methyl_methanoate):
with pytest.raises(ValueError):
am1bcc_model.compute_properties(
openff_methyl_methanoate,
check_domains=True,
error_if_unsupported=True,
)


def test_load(self, openff_methane_uncharged, expected_methane_charges):
model = GNNModel.load(EXAMPLE_AM1BCC_MODEL, eval_mode=True)
assert isinstance(model, GNNModel)
Expand Down
80 changes: 80 additions & 0 deletions openff/nagl/tests/test_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest

from openff.nagl.domains import ChemicalDomain

class TestChemicalDomain:
    """Tests for the element and SMARTS-pattern checks of ``ChemicalDomain``."""

    @pytest.mark.parametrize(
        "elements",
        [
            (1, 6, 8),
            (1, 6, 8, 9, 17, 35),
        ],
    )
    def test_check_allowed_elements(self, elements, openff_methyl_methanoate):
        """The element check is expected to pass for these element sets."""
        chemical_domain = ChemicalDomain(allowed_elements=elements)
        outcome = chemical_domain.check_allowed_elements(
            molecule=openff_methyl_methanoate
        )
        assert outcome

    @pytest.mark.parametrize(
        "elements",
        [
            (8,),
            (6, 8, 9, 17, 35),
        ],
    )
    def test_check_allowed_elements_fail_noerr(self, elements, openff_methyl_methanoate):
        """The element check is expected to fail, returning only a flag."""
        chemical_domain = ChemicalDomain(allowed_elements=elements)
        outcome = chemical_domain.check_allowed_elements(
            molecule=openff_methyl_methanoate
        )
        assert not outcome

    def test_check_allowed_elements_fail_err(self, openff_methyl_methanoate):
        """A failing element check also reports the offending atomic number."""
        chemical_domain = ChemicalDomain(allowed_elements=(8,))
        outcome = chemical_domain.check_allowed_elements(
            molecule=openff_methyl_methanoate,
            return_error_message=True,
        )
        is_allowed, message = outcome
        assert not is_allowed
        assert message == "Molecule contains forbidden element 6"

    @pytest.mark.parametrize(
        "patterns",
        [
            ("[*:1]#[*:2]",),
            ("[*:1]#[*:2]", "[#1:1]=[*:2]"),
        ],
    )
    def test_check_forbidden_patterns(self, patterns, openff_methyl_methanoate):
        """The pattern check is expected to pass for these SMARTS."""
        chemical_domain = ChemicalDomain(forbidden_patterns=patterns)
        outcome = chemical_domain.check_forbidden_patterns(
            molecule=openff_methyl_methanoate
        )
        assert outcome

    @pytest.mark.parametrize(
        "patterns",
        [
            ("[*:1]~[*:2]",),
            ("[#1:1]-[#6:2]", "[#1:1]#[*:2]"),
        ],
    )
    def test_check_forbidden_patterns_fail_noerr(self, patterns, openff_methyl_methanoate):
        """The pattern check is expected to fail, returning only a flag."""
        chemical_domain = ChemicalDomain(forbidden_patterns=patterns)
        outcome = chemical_domain.check_forbidden_patterns(
            molecule=openff_methyl_methanoate
        )
        assert not outcome

    def test_check_forbidden_patterns_fail_err(self, openff_methyl_methanoate):
        """A failing pattern check also reports the matching SMARTS."""
        chemical_domain = ChemicalDomain(forbidden_patterns=("[*:1]~[*:2]",))
        outcome = chemical_domain.check_forbidden_patterns(
            molecule=openff_methyl_methanoate,
            return_error_message=True,
        )
        is_allowed, message = outcome
        assert not is_allowed
        assert message == "Molecule contains forbidden SMARTS pattern [*:1]~[*:2]"

    def test_check_molecule_err(self, openff_methyl_methanoate):
        """The combined check surfaces the forbidden-pattern message."""
        chemical_domain = ChemicalDomain(
            allowed_elements=(8, 6, 1),
            forbidden_patterns=("[*:1]~[*:2]",),
        )
        outcome = chemical_domain.check_molecule(
            molecule=openff_methyl_methanoate,
            return_error_message=True,
        )
        is_allowed, message = outcome
        assert not is_allowed
        assert message == "Molecule contains forbidden SMARTS pattern [*:1]~[*:2]"

0 comments on commit db76183

Please sign in to comment.