diff --git a/data/data_loading.py b/data/data_loading.py index aaae9ff3..bd63db9a 100644 --- a/data/data_loading.py +++ b/data/data_loading.py @@ -34,11 +34,11 @@ from data.complex import Cochain, CochainBatch, Complex, ComplexBatch from data.datasets import ( load_sr_graph_dataset, load_tu_graph_dataset, load_zinc_graph_dataset, load_ogb_graph_dataset, - load_ring_transfer_dataset, load_ring_lookup_dataset) + load_ring_transfer_dataset, load_ring_lookup_dataset, load_pep_f_graph_dataset, load_pep_s_graph_dataset) from data.datasets import ( SRDataset, ClusterDataset, TUDataset, ComplexDataset, FlowDataset, OceanDataset, ZincDataset, CSLDataset, OGBDataset, RingTransferDataset, RingLookupDataset, - DummyDataset, DummyMolecularDataset) + DummyDataset, DummyMolecularDataset, PeptidesFunctionalDataset, PeptidesStructuralDataset) class Collater(object): @@ -133,19 +133,24 @@ def load_dataset(name, root=os.path.join(ROOT_DIR, 'datasets'), max_dim=2, fold= fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) elif name == 'PROTEINS': dataset = TUDataset(os.path.join(root, name), name, max_dim=max_dim, num_classes=2, - fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) + fold=fold, degree_as_tag=False, include_down_adj=kwargs['include_down_adj'], + init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) elif name == 'NCI1': dataset = TUDataset(os.path.join(root, name), name, max_dim=max_dim, num_classes=2, - fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) + fold=fold, degree_as_tag=False, include_down_adj=kwargs['include_down_adj'], + init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) elif name == 'NCI109': dataset = TUDataset(os.path.join(root, name), name, max_dim=max_dim, num_classes=2, - fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) + fold=fold, degree_as_tag=False, include_down_adj=kwargs['include_down_adj'], + init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) elif name == 'PTC': dataset = TUDataset(os.path.join(root, name), name, max_dim=max_dim, num_classes=2, - fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) + fold=fold, degree_as_tag=False, include_down_adj=kwargs['include_down_adj'], + init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) elif name == 'MUTAG': dataset = TUDataset(os.path.join(root, name), name, max_dim=max_dim, num_classes=2, - fold=fold, degree_as_tag=False, init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) + fold=fold, degree_as_tag=False, include_down_adj=kwargs['include_down_adj'], + init_method=init_method, max_ring_size=kwargs.get('max_ring_size', None)) elif name == 'FLOW': dataset = FlowDataset(os.path.join(root, name), name, num_points=kwargs['flow_points'], train_samples=1000, val_samples=200, train_orient=kwargs['train_orient'], @@ -159,9 +164,11 @@ def load_dataset(name, root=os.path.join(ROOT_DIR, 'datasets'), max_dim=2, fold= dataset = RingLookupDataset(os.path.join(root, name), nodes=kwargs['max_ring_size']) elif name == 'ZINC': dataset = ZincDataset(os.path.join(root, name), max_ring_size=kwargs['max_ring_size'], + include_down_adj=kwargs['include_down_adj'], use_edge_features=kwargs['use_edge_features'], n_jobs=n_jobs) elif name == 'ZINC-FULL': dataset = 
ZincDataset(os.path.join(root, name), subset=False, max_ring_size=kwargs['max_ring_size'], + include_down_adj=kwargs['include_down_adj'], use_edge_features=kwargs['use_edge_features'], n_jobs=n_jobs) elif name == 'CSL': dataset = CSLDataset(os.path.join(root, name), max_ring_size=kwargs['max_ring_size'], @@ -172,11 +179,17 @@ def load_dataset(name, root=os.path.join(ROOT_DIR, 'datasets'), max_dim=2, fold= official_name = 'ogbg-'+name.lower() dataset = OGBDataset(os.path.join(root, name), official_name, max_ring_size=kwargs['max_ring_size'], use_edge_features=kwargs['use_edge_features'], simple=kwargs['simple_features'], - init_method=init_method, n_jobs=n_jobs) + include_down_adj=kwargs['include_down_adj'], init_method=init_method, n_jobs=n_jobs) elif name == 'DUMMY': dataset = DummyDataset(os.path.join(root, name)) elif name == 'DUMMYM': dataset = DummyMolecularDataset(os.path.join(root, name)) + elif name == 'PEPTIDES-F': + dataset = PeptidesFunctionalDataset(os.path.join(root, name), max_ring_size=kwargs['max_ring_size'], + include_down_adj=kwargs['include_down_adj'], init_method=init_method, n_jobs=n_jobs) + elif name == 'PEPTIDES-S': + dataset = PeptidesStructuralDataset(os.path.join(root, name), max_ring_size=kwargs['max_ring_size'], + include_down_adj=kwargs['include_down_adj'], init_method=init_method, n_jobs=n_jobs) else: raise NotImplementedError(name) return dataset @@ -217,6 +230,12 @@ def load_graph_dataset(name, root=os.path.join(ROOT_DIR, 'datasets'), fold=0, ** elif name == 'ZINC': graph_list, train_ids, val_ids, test_ids = load_zinc_graph_dataset(root=root) data = (graph_list, train_ids, val_ids, test_ids, 1) + elif name == 'PEPTIDES-F': + graph_list, train_ids, val_ids, test_ids = load_pep_f_graph_dataset(root=root) + data = (graph_list, train_ids, val_ids, test_ids, 2) + elif name == 'PEPTIDES-S': + graph_list, train_ids, val_ids, test_ids = load_pep_s_graph_dataset(root=root) + data = (graph_list, train_ids, val_ids, test_ids, 2) elif name == 'ZINC-FULL': graph_list, train_ids, val_ids, test_ids = load_zinc_graph_dataset(root=root, subset=False) data = (graph_list, train_ids, val_ids, test_ids, 1) diff --git a/data/datasets/__init__.py b/data/datasets/__init__.py index f7a955cf..b1d388f5 100644 --- a/data/datasets/__init__.py +++ b/data/datasets/__init__.py @@ -8,6 +8,8 @@ from data.datasets.dummy import DummyDataset, DummyMolecularDataset from data.datasets.csl import CSLDataset from data.datasets.ogb import OGBDataset, load_ogb_graph_dataset +from data.datasets.peptides_functional import PeptidesFunctionalDataset, load_pep_f_graph_dataset +from data.datasets.peptides_structural import PeptidesStructuralDataset, load_pep_s_graph_dataset from data.datasets.ringtransfer import RingTransferDataset, load_ring_transfer_dataset from data.datasets.ringlookup import RingLookupDataset, load_ring_lookup_dataset diff --git a/data/datasets/ogb.py b/data/datasets/ogb.py index 408e8001..9ace4ea1 100644 --- a/data/datasets/ogb.py +++ b/data/datasets/ogb.py @@ -10,14 +10,16 @@ class OGBDataset(InMemoryComplexDataset): """This is OGB graph-property prediction. 
These are graph-wise classification tasks.""" def __init__(self, root, name, max_ring_size, use_edge_features=False, transform=None, - pre_transform=None, pre_filter=None, init_method='sum', simple=False, n_jobs=2): + pre_transform=None, pre_filter=None, init_method='sum', + include_down_adj=False, simple=False, n_jobs=2): self.name = name self._max_ring_size = max_ring_size self._use_edge_features = use_edge_features self._simple = simple self._n_jobs = n_jobs super(OGBDataset, self).__init__(root, transform, pre_transform, pre_filter, - max_dim=2, init_method=init_method, cellular=True) + max_dim=2, init_method=init_method, + include_down_adj=include_down_adj, cellular=True) self.data, self.slices, idx, self.num_tasks = self.load_dataset() self.train_ids = idx['train'] self.val_ids = idx['valid'] diff --git a/data/datasets/peptides_functional.py b/data/datasets/peptides_functional.py new file mode 100644 index 00000000..cd626e15 --- /dev/null +++ b/data/datasets/peptides_functional.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Tue May 2 21:37:42 2023 + +@author: renz +""" + +import hashlib +import os.path as osp +import os +import pickle +import shutil + +import pandas as pd +import torch +from ogb.utils import smiles2graph +from ogb.utils.torch_util import replace_numpy_with_torchtensor +from ogb.utils.url import decide_download +from torch_geometric.data import Data, download_url +from torch_geometric.data import InMemoryDataset +from data.utils import convert_graph_dataset_with_rings +from data.datasets import InMemoryComplexDataset +from tqdm import tqdm + + +class PeptidesFunctionalDataset(InMemoryComplexDataset): + """ + PyG dataset of 15,535 peptides represented as their molecular graph + (SMILES) with 10-way multi-task binary classification of their + functional classes. + + The goal is to use the molecular representation of peptides instead + of the amino acid sequence representation ('peptide_seq' field in the file, + provided for possible baseline benchmarking but not used here) to test + GNNs' representation capability. + + The 10 classes represent the following functional classes (in order): + ['antifungal', 'cell_cell_communication', 'anticancer', + 'drug_delivery_vehicle', 'antimicrobial', 'antiviral', + 'antihypertensive', 'antibacterial', 'antiparasitic', 'toxic'] + + Args: + root (string): Root directory where the dataset should be saved. + smiles2graph (callable): A callable function that converts a SMILES + string into a graph object. We use the OGB featurization. + * The default smiles2graph requires rdkit to be installed * + """ + def __init__(self, root, max_ring_size, smiles2graph=smiles2graph, + transform=None, pre_transform=None, pre_filter=None, + include_down_adj=False, init_method='sum', n_jobs=2): + self.original_root = root + self.smiles2graph = smiles2graph + self.folder = osp.join(root, 'peptides-functional') + + self.url = 'https://www.dropbox.com/s/ol2v01usvaxbsr8/peptide_multi_class_dataset.csv.gz?dl=1' + self.version = '701eb743e899f4d793f0e13c8fa5a1b4' # MD5 hash of the intended dataset file + self.url_stratified_split = 'https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1' + self.md5sum_stratified_split = '5a0114bdadc80b94fc7ae974f13ef061' + + # Check version and update if necessary.
+ release_tag = osp.join(self.folder, self.version) + if osp.isdir(self.folder) and (not osp.exists(release_tag)): + print(f"{self.__class__.__name__} has been updated.") + if input("Will you update the dataset now? (y/N)\n").lower() == 'y': + shutil.rmtree(self.folder) + + self.name = 'peptides_functional' + self._max_ring_size = max_ring_size + self._use_edge_features = True + self._n_jobs = n_jobs + super(PeptidesFunctionalDataset, self).__init__(root, transform, pre_transform, pre_filter, + max_dim=2, init_method=init_method, include_down_adj=include_down_adj, + cellular=True, num_classes=1) + + self.data, self.slices, idx, self.num_tasks = self.load_dataset() + self.train_ids = idx['train'] + self.val_ids = idx['val'] + self.test_ids = idx['test'] + + self.num_node_type = 9 + self.num_edge_type = 3 + + @property + def raw_file_names(self): + return 'peptide_multi_class_dataset.csv.gz' + + @property + def processed_file_names(self): + return [f'{self.name}_complex.pt', f'{self.name}_idx.pt', f'{self.name}_tasks.pt'] + + + @property + def processed_dir(self): + """Override to change name based on edge and simple feats""" + directory = super(PeptidesFunctionalDataset, self).processed_dir + suffix1 = f"_{self._max_ring_size}rings" if self._cellular else "" + suffix2 = "-E" if self._use_edge_features else "" + return directory + suffix1 + suffix2 + + + def _md5sum(self, path): + hash_md5 = hashlib.md5() + with open(path, 'rb') as f: + buffer = f.read() + hash_md5.update(buffer) + return hash_md5.hexdigest() + + def download(self): + if decide_download(self.url): + path = download_url(self.url, self.raw_dir) + # Save to disk the MD5 hash of the downloaded file. + hash = self._md5sum(path) + if hash != self.version: + raise ValueError("Unexpected MD5 hash of the downloaded file") + open(osp.join(self.root, hash), 'w').close() + # Download train/val/test splits.
+ path_split1 = download_url(self.url_stratified_split, self.root) + assert self._md5sum(path_split1) == self.md5sum_stratified_split + old_df_name = osp.join(self.raw_dir, + 'peptide_multi_class_dataset.csv.gz?dl=1') + new_df_name = osp.join(self.raw_dir, + 'peptide_multi_class_dataset.csv.gz') + + + old_split_file = osp.join(self.root, + "splits_random_stratified_peptide.pickle?dl=1") + new_split_file = osp.join(self.root, + "splits_random_stratified_peptide.pickle") + os.rename(old_df_name, new_df_name) + os.rename(old_split_file, new_split_file) + + else: + print('Stop download.') + exit(-1) + + def load_dataset(self): + """Load the dataset from here and process it if it doesn't exist""" + print("Loading dataset from disk...") + data, slices = torch.load(self.processed_paths[0]) + idx = torch.load(self.processed_paths[1]) + tasks = torch.load(self.processed_paths[2]) + return data, slices, idx, tasks + + def process(self): + data_df = pd.read_csv(osp.join(self.raw_dir, + 'peptide_multi_class_dataset.csv.gz')) + smiles_list = data_df['smiles'] + + print('Converting SMILES strings into graphs...') + data_list = [] + for i in tqdm(range(len(smiles_list))): + data = Data() + + smiles = smiles_list[i] + graph = self.smiles2graph(smiles) + + assert (len(graph['edge_feat']) == graph['edge_index'].shape[1]) + assert (len(graph['node_feat']) == graph['num_nodes']) + + data.__num_nodes__ = int(graph['num_nodes']) + data.edge_index = torch.from_numpy(graph['edge_index']).to( + torch.int64) + data.edge_attr = torch.from_numpy(graph['edge_feat']).to( + torch.int64) + data.x = torch.from_numpy(graph['node_feat']).to(torch.int64) + data.y = torch.Tensor([eval(data_df['labels'].iloc[i])]) + + data_list.append(data) + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + + split_idx = self.get_idx_split() + + # NB: the init method would basically have no effect if + # we use edge features and do not initialize rings. + print(f"Converting the {self.name} dataset to a cell complex...") + complexes, _, _ = convert_graph_dataset_with_rings( + data_list, + max_ring_size=self._max_ring_size, + include_down_adj=self.include_down_adj, + init_method=self._init_method, + init_edges=self._use_edge_features, + init_rings=False, + n_jobs=self._n_jobs) + + print(f'Saving processed dataset in {self.processed_paths[0]}...') + torch.save(self.collate(complexes, self.max_dim), self.processed_paths[0]) + + print(f'Saving idx in {self.processed_paths[1]}...') + torch.save(split_idx, self.processed_paths[1]) + + print(f'Saving num_tasks in {self.processed_paths[2]}...') + torch.save(10, self.processed_paths[2]) + + def get_idx_split(self): + """ Get dataset splits. + + Returns: + Dict with 'train', 'val', 'test', splits indices. + """ + split_file = osp.join(self.root, + "splits_random_stratified_peptide.pickle") + with open(split_file, 'rb') as f: + splits = pickle.load(f) + split_dict = replace_numpy_with_torchtensor(splits) + split_dict['valid'] = split_dict['val'] + return split_dict + + +def load_pep_f_graph_dataset(root): + raw_dir = osp.join(root, 'raw') + data_df = pd.read_csv(osp.join(raw_dir, + 'peptide_multi_class_dataset.csv.gz')) + smiles_list = data_df['smiles'] + + print('Converting SMILES strings into graphs...') + data_list = [] + for i in tqdm(range(len(smiles_list))): + data = Data() + + smiles = smiles_list[i] + graph = smiles2graph(smiles) + + assert (len(graph['edge_feat']) == graph['edge_index'].shape[1]) + assert (len(graph['node_feat']) == graph['num_nodes']) + + data.__num_nodes__ = int(graph['num_nodes']) + data.edge_index = torch.from_numpy(graph['edge_index']).to( + torch.int64) + data.edge_attr = torch.from_numpy(graph['edge_feat']).to( + torch.int64) + data.x = torch.from_numpy(graph['node_feat']).to(torch.int64) + # The 10 binary task labels are stored as a string-encoded list in the 'labels' column. + data.y = torch.Tensor([eval(data_df['labels'].iloc[i])]) + + data_list.append(data) + + dataset = InMemoryDataset.collate(data_list) + + # Get split file. + split_file = osp.join(root, + "splits_random_stratified_peptide.pickle") + with open(split_file, 'rb') as f: + splits = pickle.load(f) + split_dict = replace_numpy_with_torchtensor(splits) + split_dict['valid'] = split_dict['val'] + + return dataset, split_dict['train'], split_dict['valid'], split_dict['test'] \ No newline at end of file diff --git a/data/datasets/peptides_structural.py b/data/datasets/peptides_structural.py new file mode 100644 index 00000000..80fe7a98 --- /dev/null +++ b/data/datasets/peptides_structural.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Tue May 2 21:37:42 2023 + +@author: renz +""" + +import hashlib +import os.path as osp +import os +import pickle +import shutil + +import pandas as pd +import torch +from ogb.utils import smiles2graph +from ogb.utils.torch_util import replace_numpy_with_torchtensor +from ogb.utils.url import decide_download +from torch_geometric.data import Data, download_url +from torch_geometric.data import InMemoryDataset +from data.utils import convert_graph_dataset_with_rings +from data.datasets import InMemoryComplexDataset +from tqdm import tqdm + + +class PeptidesStructuralDataset(InMemoryComplexDataset): + """ + PyG dataset of 15,535 small peptides represented as their molecular + graph (SMILES) with 11 regression targets derived from the peptide's + 3D structure. + + The original amino acid sequence representation is provided in the + 'peptide_seq' field and the distance between atoms in the 'self_dist_matrix' + field of the dataset file, but neither is used here as part of the input. + + The 11 regression targets were precomputed from molecule XYZ: + Inertia_mass_[a-c]: The principal component of the inertia of the + mass, with some normalizations. Sorted + Inertia_valence_[a-c]: The principal component of the inertia of the + Hydrogen atoms. This is basically a measure of the 3D + distribution of hydrogens. Sorted + length_[a-c]: The length around the 3 main geometric axes of + the 3D objects (without considering atom types). Sorted + Spherocity: SpherocityIndex descriptor computed by + rdkit.Chem.rdMolDescriptors.CalcSpherocityIndex + Plane_best_fit: Plane of best fit (PBF) descriptor computed by + rdkit.Chem.rdMolDescriptors.CalcPBF + Args: + root (string): Root directory where the dataset should be saved. + smiles2graph (callable): A callable function that converts a SMILES + string into a graph object. We use the OGB featurization.
* The default smiles2graph requires rdkit to be installed * + """ + def __init__(self, root, max_ring_size, smiles2graph=smiles2graph, + transform=None, pre_transform=None, pre_filter=None, + include_down_adj=False, init_method='sum', n_jobs=2): + self.original_root = root + self.smiles2graph = smiles2graph + self.folder = osp.join(root, 'peptides-structural') + + self.url = 'https://www.dropbox.com/s/464u3303eu2u4zp/peptide_structure_dataset.csv.gz?dl=1' + self.version = '9786061a34298a0684150f2e4ff13f47' # MD5 hash of the intended dataset file + self.url_stratified_split = 'https://www.dropbox.com/s/9dfifzft1hqgow6/splits_random_stratified_peptide_structure.pickle?dl=1' + self.md5sum_stratified_split = '5a0114bdadc80b94fc7ae974f13ef061' + + # Check version and update if necessary. + release_tag = osp.join(self.folder, self.version) + if osp.isdir(self.folder) and (not osp.exists(release_tag)): + print(f"{self.__class__.__name__} has been updated.") + if input("Will you update the dataset now? (y/N)\n").lower() == 'y': + shutil.rmtree(self.folder) + + self.name = 'peptides_structural' + self._max_ring_size = max_ring_size + self._use_edge_features = True + self._n_jobs = n_jobs + super(PeptidesStructuralDataset, self).__init__(root, transform, pre_transform, pre_filter, + max_dim=2, init_method=init_method, include_down_adj=include_down_adj, + cellular=True, num_classes=1) + + self.data, self.slices, idx, self.num_tasks = self.load_dataset() + self.train_ids = idx['train'] + self.val_ids = idx['val'] + self.test_ids = idx['test'] + self.num_node_type = 9 + self.num_edge_type = 3 + + @property + def raw_file_names(self): + return 'peptide_structure_dataset.csv.gz' + + @property + def processed_file_names(self): + return [f'{self.name}_complex.pt', f'{self.name}_idx.pt', f'{self.name}_tasks.pt'] + + + @property + def processed_dir(self): + """Override to change name based on edge and simple feats""" + directory = super(PeptidesStructuralDataset, self).processed_dir + suffix1 = f"_{self._max_ring_size}rings" if self._cellular else "" + suffix2 = "-E" if self._use_edge_features else "" + return directory + suffix1 + suffix2 + + + def _md5sum(self, path): + hash_md5 = hashlib.md5() + with open(path, 'rb') as f: + buffer = f.read() + hash_md5.update(buffer) + return hash_md5.hexdigest() + + def download(self): + if decide_download(self.url): + path = download_url(self.url, self.raw_dir) + # Save to disk the MD5 hash of the downloaded file. + hash = self._md5sum(path) + if hash != self.version: + raise ValueError("Unexpected MD5 hash of the downloaded file") + open(osp.join(self.root, hash), 'w').close() + # Download train/val/test splits.
+ path_split1 = download_url(self.url_stratified_split, self.root) + assert self._md5sum(path_split1) == self.md5sum_stratified_split + + old_split_file = osp.join(self.root, + "splits_random_stratified_peptide_structure.pickle?dl=1") + new_split_file = osp.join(self.root, + "splits_random_stratified_peptide_structure.pickle") + old_df_name = osp.join(self.raw_dir, + 'peptide_structure_dataset.csv.gz?dl=1') + new_df_name = osp.join(self.raw_dir, + 'peptide_structure_dataset.csv.gz') + os.rename(old_split_file, new_split_file) + os.rename(old_df_name, new_df_name) + else: + print('Stop download.') + exit(-1) + + def load_dataset(self): + """Load the dataset from here and process it if it doesn't exist""" + print("Loading dataset from disk...") + data, slices = torch.load(self.processed_paths[0]) + idx = torch.load(self.processed_paths[1]) + tasks = torch.load(self.processed_paths[2]) + return data, slices, idx, tasks + + def process(self): + data_df = pd.read_csv(osp.join(self.raw_dir, + 'peptide_structure_dataset.csv.gz')) + smiles_list = data_df['smiles'] + target_names = ['Inertia_mass_a', 'Inertia_mass_b', 'Inertia_mass_c', + 'Inertia_valence_a', 'Inertia_valence_b', + 'Inertia_valence_c', 'length_a', 'length_b', 'length_c', + 'Spherocity', 'Plane_best_fit'] + # Normalize to zero mean and unit standard deviation. + data_df.loc[:, target_names] = data_df.loc[:, target_names].apply( + lambda x: (x - x.mean()) / x.std(), axis=0) + + print('Converting SMILES strings into graphs...') + data_list = [] + for i in tqdm(range(len(smiles_list))): + data = Data() + + smiles = smiles_list[i] + y = data_df.iloc[i][target_names] + graph = self.smiles2graph(smiles) + + assert (len(graph['edge_feat']) == graph['edge_index'].shape[1]) + assert (len(graph['node_feat']) == graph['num_nodes']) + + data.__num_nodes__ = int(graph['num_nodes']) + data.edge_index = torch.from_numpy(graph['edge_index']).to( + torch.int64) + data.edge_attr = torch.from_numpy(graph['edge_feat']).to( + torch.int64) + data.x = torch.from_numpy(graph['node_feat']).to(torch.int64) + data.y = torch.Tensor([y]) + + data_list.append(data) + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + split_idx = self.get_idx_split() + + # NB: the init method would basically have no effect if + # we use edge features and do not initialize rings. + print(f"Converting the {self.name} dataset to a cell complex...") + complexes, _, _ = convert_graph_dataset_with_rings( + data_list, + max_ring_size=self._max_ring_size, + include_down_adj=self.include_down_adj, + init_method=self._init_method, + init_edges=self._use_edge_features, + init_rings=False, + n_jobs=self._n_jobs) + + print(f'Saving processed dataset in {self.processed_paths[0]}...') + torch.save(self.collate(complexes, self.max_dim), self.processed_paths[0]) + + print(f'Saving idx in {self.processed_paths[1]}...') + torch.save(split_idx, self.processed_paths[1]) + + print(f'Saving num_tasks in {self.processed_paths[2]}...') + torch.save(11, self.processed_paths[2]) + + def get_idx_split(self): + """ Get dataset splits. + + Returns: + Dict with 'train', 'val', 'test', splits indices. 
+ """ + split_file = osp.join(self.root, + "splits_random_stratified_peptide_structure.pickle") + with open(split_file, 'rb') as f: + splits = pickle.load(f) + split_dict = replace_numpy_with_torchtensor(splits) + split_dict['valid'] = split_dict['val'] + return split_dict + + +def load_pep_s_graph_dataset(root): + raw_dir = osp.join(root, 'raw') + data_df = pd.read_csv(osp.join(raw_dir, + 'peptide_structure_dataset.csv.gz')) + smiles_list = data_df['smiles'] + target_names = ['Inertia_mass_a', 'Inertia_mass_b', 'Inertia_mass_c', + 'Inertia_valence_a', 'Inertia_valence_b', + 'Inertia_valence_c', 'length_a', 'length_b', 'length_c', + 'Spherocity', 'Plane_best_fit'] + # Normalize to zero mean and unit standard deviation. + data_df.loc[:, target_names] = data_df.loc[:, target_names].apply( + lambda x: (x - x.mean()) / x.std(), axis=0) + + print('Converting SMILES strings into graphs...') + data_list = [] + for i in tqdm(range(len(smiles_list))): + data = Data() + + smiles = smiles_list[i] + y = data_df.iloc[i][target_names] + graph = smiles2graph(smiles) + + assert (len(graph['edge_feat']) == graph['edge_index'].shape[1]) + assert (len(graph['node_feat']) == graph['num_nodes']) + + data.__num_nodes__ = int(graph['num_nodes']) + data.edge_index = torch.from_numpy(graph['edge_index']).to( + torch.int64) + data.edge_attr = torch.from_numpy(graph['edge_feat']).to( + torch.int64) + data.x = torch.from_numpy(graph['node_feat']).to(torch.int64) + data.y = torch.Tensor([y]) + + data_list.append(data) + + dataset = InMemoryDataset.collate(data_list) + + #get split file + split_file = osp.join(root, + "splits_random_stratified_peptide_structure.pickle") + with open(split_file, 'rb') as f: + splits = pickle.load(f) + split_dict = replace_numpy_with_torchtensor(splits) + split_dict['valid'] = split_dict['val'] + + return dataset, split_dict['train'], split_dict['valid'], split_dict['test'] \ No newline at end of file diff --git a/data/datasets/zinc.py b/data/datasets/zinc.py index 4ccfd066..5060d57e 100644 --- a/data/datasets/zinc.py +++ b/data/datasets/zinc.py @@ -10,14 +10,16 @@ class ZincDataset(InMemoryComplexDataset): """This is ZINC from the Benchmarking GNNs paper. This is a graph regression task.""" def __init__(self, root, max_ring_size, use_edge_features=False, transform=None, - pre_transform=None, pre_filter=None, subset=True, n_jobs=2): + pre_transform=None, pre_filter=None, subset=True, + include_down_adj=False, n_jobs=2): self.name = 'ZINC' self._max_ring_size = max_ring_size self._use_edge_features = use_edge_features self._subset = subset self._n_jobs = n_jobs super(ZincDataset, self).__init__(root, transform, pre_transform, pre_filter, - max_dim=2, cellular=True, num_classes=1) + max_dim=2, cellular=True, + include_down_adj=include_down_adj, num_classes=1) self.data, self.slices, idx = self.load_dataset() self.train_ids = idx[0] diff --git a/data/tu_utils.py b/data/tu_utils.py index 73998d99..439263da 100644 --- a/data/tu_utils.py +++ b/data/tu_utils.py @@ -25,6 +25,7 @@ SOFTWARE. """ +import os import networkx as nx import numpy as np import torch diff --git a/exp/parser.py b/exp/parser.py index 7b120806..54244d1c 100644 --- a/exp/parser.py +++ b/exp/parser.py @@ -19,6 +19,8 @@ def get_parser(): help='model, possible choices: cin, dummy, ... 
(default: cin)') parser.add_argument('--use_coboundaries', type=str, default='False', help='whether to use coboundary features for up-messages in sparse_cin (default: False)') + parser.add_argument('--include_down_adj', action='store_true', + help='whether to use lower adjacencies (i.e. CIN++ networks) (default: False)') # ^^^ here we explicitly pass it as string as easier to handle in tuning parser.add_argument('--indrop_rate', type=float, default=0.0, help='inputs dropout rate for molec models(default: 0.0)') @@ -141,6 +143,8 @@ def validate_args(args): assert args.graph_norm == 'bn' elif args.dataset.startswith('ZINC'): assert args.model.startswith('embed') + if args.model == 'embed_cin++': + assert args.include_down_adj is True assert args.task_type == 'regression' assert args.minimize assert args.eval_metric == 'mae' @@ -149,7 +153,9 @@ def validate_args(args): elif args.dataset in ['MOLHIV', 'MOLPCBA', 'MOLTOX21', 'MOLTOXCAST', 'MOLMUV', 'MOLBACE', 'MOLBBBP', 'MOLCLINTOX', 'MOLSIDER', 'MOLESOL', 'MOLFREESOLV', 'MOLLIPO']: - assert args.model == 'ogb_embed_sparse_cin' + assert args.model == 'ogb_embed_sparse_cin' or args.model == "ogb_embed_cin++" + if args.model == 'ogb_embed_cin++': + assert args.include_down_adj is True assert args.eval_metric == 'ogbg-'+args.dataset.lower() assert args.jump_mode is None if args.dataset in ['MOLESOL', 'MOLFREESOLV', 'MOLLIPO']: @@ -178,4 +184,3 @@ def validate_args(args): assert not args.untrained assert not args.simple_features assert not args.minimize - diff --git a/exp/run_exp.py b/exp/run_exp.py index e392efc9..29b320eb 100644 --- a/exp/run_exp.py +++ b/exp/run_exp.py @@ -11,8 +11,8 @@ from exp.train_utils import train, eval, Evaluator from exp.parser import get_parser, validate_args from mp.graph_models import GIN0, GINWithJK -from mp.models import CIN0, Dummy, SparseCIN, EdgeOrient, EdgeMPNN, MessagePassingAgnostic -from mp.molec_models import EmbedSparseCIN, OGBEmbedSparseCIN, EmbedSparseCINNoRings, EmbedGIN +from mp.models import CIN0, Dummy, SparseCIN, CINpp, EdgeOrient, EdgeMPNN, MessagePassingAgnostic +from mp.molec_models import EmbedSparseCIN, EmbedCINpp, OGBEmbedSparseCIN, OGBEmbedCINpp, EmbedSparseCINNoRings, EmbedGIN from mp.ring_exp_models import RingSparseCIN, RingGIN @@ -79,6 +79,7 @@ def main(args): flow_points=args.flow_points, flow_classes=args.flow_classes, max_ring_size=args.max_ring_size, use_edge_features=args.use_edge_features, + include_down_adj=args.include_down_adj, simple_features=args.simple_features, n_jobs=args.preproc_jobs, train_orient=args.train_orient, test_orient=args.test_orient) if args.tune: @@ -135,6 +136,22 @@ def main(args): graph_norm=args.graph_norm, # normalization layer readout_dims=readout_dims # readout_dims ).to(device) + elif args.model == 'cin++': + model = CINpp(dataset.num_features_in_dim(0), # num_input_features + dataset.num_classes, # num_classes + args.num_layers, # num_layers + args.emb_dim, # hidden + dropout_rate=args.drop_rate, # dropout rate + max_dim=dataset.max_dim, # max_dim + jump_mode=args.jump_mode, # jump mode + nonlinearity=args.nonlinearity, # nonlinearity + readout=args.readout, # readout + final_readout=args.final_readout, # final readout + apply_dropout_before=args.drop_position, # where to apply dropout + use_coboundaries=use_coboundaries, # whether to use coboundaries in up-msg + graph_norm=args.graph_norm, # normalization layer + readout_dims=readout_dims # readout_dims + ).to(device) elif args.model == 'ring_sparse_cin': model = RingSparseCIN( 
dataset.num_features_in_dim(0), # num_input_features @@ -198,7 +215,7 @@ def main(args): nonlinearity=args.nonlinearity, # nonlinearity dropout_rate=args.drop_rate, # dropout rate fully_invar=args.fully_orient_invar - ).to(device) + ).to(device) elif args.model == 'edge_mpnn': model = EdgeMPNN(1, dataset.num_classes, @@ -208,7 +225,7 @@ def main(args): nonlinearity=args.nonlinearity, # nonlinearity dropout_rate=args.drop_rate, # dropout rate fully_invar=args.fully_orient_invar, - ).to(device) + ).to(device) elif args.model == 'embed_sparse_cin': model = EmbedSparseCIN(dataset.num_node_type, # The number of atomic types dataset.num_edge_type, # The number of bond types @@ -227,6 +244,24 @@ def main(args): graph_norm=args.graph_norm, # normalization layer readout_dims=readout_dims # readout_dims ).to(device) + elif args.model == 'embed_cin++': + model = EmbedCINpp(atom_types=dataset.num_node_type, # The number of atomic types + bond_types=dataset.num_edge_type, # The number of bond types + out_size=dataset.num_classes, # num_classes + num_layers=args.num_layers, # num_layers + hidden=args.emb_dim, # hidden + dropout_rate=args.drop_rate, # dropout rate + max_dim=dataset.max_dim, # max_dim + jump_mode=args.jump_mode, # jump mode + nonlinearity=args.nonlinearity, # nonlinearity + readout=args.readout, # readout + final_readout=args.final_readout, # final readout + apply_dropout_before=args.drop_position, # where to apply dropout + use_coboundaries=use_coboundaries, + embed_edge=args.use_edge_features, + graph_norm=args.graph_norm, # normalization layer + readout_dims=readout_dims # readout_dims + ).to(device) elif args.model == 'embed_sparse_cin_no_rings': model = EmbedSparseCINNoRings(dataset.num_node_type, # The number of atomic types dataset.num_edge_type, # The number of bond types @@ -253,7 +288,7 @@ def main(args): readout=args.readout, # readout apply_dropout_before=args.drop_position, # where to apply dropout embed_edge=args.use_edge_features, - ).to(device) + ).to(device) # TODO: handle this as above elif args.model == 'ogb_embed_sparse_cin': model = OGBEmbedSparseCIN(dataset.num_tasks, # out_size @@ -272,6 +307,23 @@ def main(args): graph_norm=args.graph_norm, # normalization layer readout_dims=readout_dims # readout_dims ).to(device) + elif args.model == 'ogb_embed_cin++': + model = OGBEmbedCINpp(dataset.num_tasks, # out_size + args.num_layers, # num_layers + args.emb_dim, # hidden + dropout_rate=args.drop_rate, # dropout_rate + indropout_rate=args.indrop_rate, # in-dropout_rate + max_dim=dataset.max_dim, # max_dim + jump_mode=args.jump_mode, # jump_mode + nonlinearity=args.nonlinearity, # nonlinearity + readout=args.readout, # readout + final_readout=args.final_readout, # final readout + apply_dropout_before=args.drop_position, # where to apply dropout + use_coboundaries=use_coboundaries, # whether to use coboundaries + embed_edge=args.use_edge_features, # whether to use edge feats + graph_norm=args.graph_norm, # normalization layer + readout_dims=readout_dims # readout_dims + ).to(device) else: raise ValueError('Invalid model type {}.'.format(args.model)) diff --git a/exp/scripts/cin++-molhiv-small.sh b/exp/scripts/cin++-molhiv-small.sh new file mode 100644 index 00000000..6ab7dbc7 --- /dev/null +++ b/exp/scripts/cin++-molhiv-small.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +python -m exp.run_mol_exp \ + --device 0 \ + --start_seed 0 \ + --stop_seed 9 \ + --exp_name cin++-molhiv-small \ + --dataset MOLHIV \ + --model ogb_embed_cin++ \ + --include_down_adj \ + --use_coboundaries True \ + 
--indrop_rate 0.0 \ + --drop_rate 0.5 \ + --graph_norm bn \ + --drop_position lin2 \ + --nonlinearity relu \ + --readout mean \ + --final_readout sum \ + --lr 0.0001 \ + --lr_scheduler None \ + --num_layers 2 \ + --emb_dim 48 \ + --batch_size 128 \ + --epochs 150 \ + --num_workers 2 \ + --preproc_jobs 32 \ + --task_type bin_classification \ + --eval_metric ogbg-molhiv \ + --max_dim 2 \ + --max_ring_size 6 \ + --init_method sum \ + --train_eval_period 10 \ + --use_edge_features \ + --dump_curves diff --git a/exp/scripts/cin++-molhiv.sh b/exp/scripts/cin++-molhiv.sh new file mode 100644 index 00000000..ef22ffa6 --- /dev/null +++ b/exp/scripts/cin++-molhiv.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +python -m exp.run_mol_exp \ + --device 0 \ + --start_seed 0 \ + --stop_seed 9 \ + --exp_name cin++-molhiv \ + --dataset MOLHIV \ + --model ogb_embed_cin++ \ + --include_down_adj \ + --use_coboundaries True \ + --indrop_rate 0.0 \ + --drop_rate 0.5 \ + --graph_norm bn \ + --drop_position lin2 \ + --nonlinearity relu \ + --readout mean \ + --final_readout sum \ + --lr 0.0001 \ + --lr_scheduler None \ + --num_layers 2 \ + --emb_dim 64 \ + --batch_size 128 \ + --epochs 150 \ + --num_workers 2 \ + --preproc_jobs 32 \ + --task_type bin_classification \ + --eval_metric ogbg-molhiv \ + --max_dim 2 \ + --max_ring_size 6 \ + --init_method sum \ + --train_eval_period 10 \ + --use_edge_features \ + --dump_curves diff --git a/exp/scripts/cin++-nci109.sh b/exp/scripts/cin++-nci109.sh new file mode 100644 index 00000000..0e52add4 --- /dev/null +++ b/exp/scripts/cin++-nci109.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +python -m exp.run_tu_exp \ + --device 0 \ + --exp_name cin++-nci109 \ + --dataset NCI109 \ + --train_eval_period 50 \ + --epochs 150 \ + --batch_size 32 \ + --drop_rate 0.0 \ + --drop_position lin2 \ + --emb_dim 64 \ + --max_dim 2 \ + --final_readout sum \ + --init_method mean \ + --jump_mode 'cat' \ + --lr 0.001 \ + --graph_norm bn \ + --model cin++ \ + --include_down_adj \ + --nonlinearity relu \ + --num_layers 4 \ + --readout sum \ + --max_ring_size 6 \ + --task_type classification \ + --eval_metric accuracy \ + --lr_scheduler 'StepLR' \ + --lr_scheduler_decay_rate 0.5 \ + --lr_scheduler_decay_steps 20 \ + --use_coboundaries True \ + --dump_curves \ + --preproc_jobs 4 \ No newline at end of file diff --git a/exp/scripts/cin++-pep-f.sh b/exp/scripts/cin++-pep-f.sh new file mode 100644 index 00000000..e4112e55 --- /dev/null +++ b/exp/scripts/cin++-pep-f.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +python -m exp.run_mol_exp \ + --device 0 \ + --start_seed 0 \ + --stop_seed 3 \ + --exp_name cwn-pep-f-500k \ + --dataset PEPTIDES-F \ + --model ogb_embed_cin++ \ + --include_down_adj \ + --use_coboundaries True \ + --indrop_rate 0.0 \ + --drop_rate 0.15 \ + --graph_norm bn \ + --drop_position lin2 \ + --nonlinearity relu \ + --readout sum \ + --final_readout sum \ + --lr 0.001 \ + --num_layers 3 \ + --emb_dim 64 \ + --batch_size 128 \ + --epochs 1000 \ + --num_workers 0 \ + --preproc_jobs 32 \ + --task_type bin_classification \ + --eval_metric ap \ + --max_dim 2 \ + --max_ring_size 8 \ + --lr_scheduler 'ReduceLROnPlateau' \ + --init_method sum \ + --train_eval_period 10 \ + --use_edge_features \ + --lr_scheduler_patience 15 \ + --dump_curves diff --git a/exp/scripts/cin++-pep-s.sh b/exp/scripts/cin++-pep-s.sh new file mode 100644 index 00000000..acac6772 --- /dev/null +++ b/exp/scripts/cin++-pep-s.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +python -m exp.run_mol_exp \ + --device 0 \ + --start_seed 0 \ + --stop_seed 3 \ + --exp_name 
cwn-pep-s-500k \ + --dataset PEPTIDES-S \ + --model ogb_embed_cin++ \ + --include_down_adj \ + --use_coboundaries True \ + --indrop_rate 0.0 \ + --drop_rate 0.0 \ + --graph_norm bn \ + --drop_position lin2 \ + --nonlinearity relu \ + --readout mean \ + --final_readout sum \ + --lr 0.001 \ + --num_layers 3 \ + --emb_dim 64 \ + --batch_size 128 \ + --epochs 1000 \ + --num_workers 0 \ + --preproc_jobs 32 \ + --task_type regression \ + --eval_metric mae \ + --max_dim 2 \ + --max_ring_size 8 \ + --lr_scheduler 'ReduceLROnPlateau' \ + --init_method sum \ + --minimize \ + --early_stop \ + --train_eval_period 10 \ + --use_edge_features \ + --lr_scheduler_patience 20 \ + --dump_curves diff --git a/exp/scripts/cin++-zinc-500k.sh b/exp/scripts/cin++-zinc-500k.sh new file mode 100644 index 00000000..136337e3 --- /dev/null +++ b/exp/scripts/cin++-zinc-500k.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +python -m exp.run_mol_exp \ + --device 0 \ + --start_seed 0 \ + --stop_seed 9 \ + --exp_name cin++-zinc-500k \ + --dataset ZINC \ + --train_eval_period 20 \ + --epochs 1000 \ + --batch_size 128 \ + --drop_rate 0.0 \ + --drop_position lin2 \ + --emb_dim 64 \ + --max_dim 2 \ + --final_readout sum \ + --init_method sum \ + --lr 0.001 \ + --graph_norm bn \ + --model embed_cin++ \ + --include_down_adj \ + --nonlinearity relu \ + --num_layers 3 \ + --readout sum \ + --max_ring_size 18 \ + --task_type regression \ + --eval_metric mae \ + --minimize \ + --lr_scheduler 'ReduceLROnPlateau' \ + --use_coboundaries True \ + --use_edge_features \ + --early_stop \ + --lr_scheduler_patience 20 \ + --dump_curves \ + --preproc_jobs 32 diff --git a/exp/scripts/cin++-zinc-small.sh b/exp/scripts/cin++-zinc-small.sh new file mode 100644 index 00000000..3e7703df --- /dev/null +++ b/exp/scripts/cin++-zinc-small.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +python -m exp.run_mol_exp \ + --device 0 \ + --start_seed 0 \ + --stop_seed 9 \ + --exp_name cin++-zinc-small \ + --dataset ZINC \ + --train_eval_period 20 \ + --epochs 1000 \ + --batch_size 128 \ + --drop_rate 0.0 \ + --drop_position lin2 \ + --emb_dim 48 \ + --max_dim 2 \ + --final_readout sum \ + --init_method sum \ + --lr 0.001 \ + --graph_norm bn \ + --model embed_cin++ \ + --include_down_adj \ + --nonlinearity relu \ + --num_layers 2 \ + --readout sum \ + --max_ring_size 18 \ + --task_type regression \ + --eval_metric mae \ + --minimize \ + --lr_scheduler 'ReduceLROnPlateau' \ + --use_coboundaries True \ + --use_edge_features \ + --early_stop \ + --lr_scheduler_patience 20 \ + --dump_curves \ + --preproc_jobs 32 diff --git a/exp/scripts/cin++-zinc.sh b/exp/scripts/cin++-zinc.sh new file mode 100644 index 00000000..7a7f25eb --- /dev/null +++ b/exp/scripts/cin++-zinc.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +python -m exp.run_mol_exp \ + --device 0 \ + --start_seed 0 \ + --stop_seed 9 \ + --exp_name cin++-zinc \ + --dataset ZINC \ + --train_eval_period 20 \ + --epochs 1000 \ + --batch_size 128 \ + --drop_rate 0.0 \ + --drop_position lin2 \ + --emb_dim 128 \ + --max_dim 2 \ + --final_readout sum \ + --init_method sum \ + --lr 0.001 \ + --graph_norm bn \ + --model embed_cin++ \ + --include_down_adj \ + --nonlinearity relu \ + --num_layers 4 \ + --readout sum \ + --max_ring_size 18 \ + --task_type regression \ + --eval_metric mae \ + --minimize \ + --lr_scheduler 'ReduceLROnPlateau' \ + --use_coboundaries True \ + --use_edge_features \ + --early_stop \ + --lr_scheduler_patience 20 \ + --dump_curves \ + --preproc_jobs 32 diff --git a/exp/train_utils.py b/exp/train_utils.py index 61559f27..44591760 100644 ---
a/exp/train_utils.py +++ b/exp/train_utils.py @@ -153,6 +153,8 @@ def __init__(self, metric, **kwargs): self.p_norm = kwargs.get('p', 2) elif metric == 'accuracy': self.eval_fn = self._accuracy + elif metric == 'ap': + self.eval_fn = self._ap elif metric == 'mae': self.eval_fn = self._mae elif metric.startswith('ogbg-mol'): @@ -183,6 +185,15 @@ def _accuracy(self, input_dict, **kwargs): assert y_pred is not None metric = met.accuracy_score(y_true, y_pred) return metric + + def _ap(self, input_dict, **kwargs): + y_true = input_dict['y_true'] + y_pred = input_dict['y_pred'] + assert y_true is not None + assert y_pred is not None + metric = met.average_precision_score(y_true, y_pred) + return metric + def _mae(self, input_dict, **kwargs): y_true = input_dict['y_true'] diff --git a/mp/layers.py b/mp/layers.py index b6c23980..2c1f2c80 100644 --- a/mp/layers.py +++ b/mp/layers.py @@ -1,6 +1,6 @@ import torch -from typing import Callable, Optional +from typing import Any, Callable, Optional from torch import Tensor from mp.cell_mp import CochainMessagePassing, CochainMessagePassingParams from torch_geometric.nn.inits import reset @@ -213,6 +213,52 @@ def message_up(self, up_x_j: Tensor, up_attr: Tensor) -> Tensor: def message_boundary(self, boundary_x_j: Tensor) -> Tensor: return self.msg_boundaries_nn(boundary_x_j) +class CINppCochainConv(SparseCINCochainConv): + """A cochain convolution extending SparseCINCochainConv with messages + from lower-adjacent cells (the extra message channel of CIN++). + """ + def __init__(self, dim: int, up_msg_size: int, down_msg_size: int, boundary_msg_size: int, + msg_up_nn: Callable[..., Any], msg_boundaries_nn: Callable[..., Any], msg_down_nn: Callable[..., Any], + update_up_nn: Callable[..., Any], update_boundaries_nn: Callable[..., Any], update_down_nn: Callable[..., Any], + combine_nn: Callable[..., Any], eps: float = 0, train_eps: bool = False): + super(CINppCochainConv, self).__init__(dim, up_msg_size, down_msg_size, boundary_msg_size, + msg_up_nn, msg_boundaries_nn, + update_up_nn, update_boundaries_nn, + combine_nn, eps, train_eps) + + self.msg_down_nn = msg_down_nn + self.update_down_nn = update_down_nn + if train_eps: + self.eps3 = torch.nn.Parameter(torch.Tensor([eps])) + else: + self.register_buffer('eps3', torch.Tensor([eps])) + + reset(self.msg_down_nn) + reset(self.update_down_nn) + self.eps3.data.fill_(self.initial_eps) + + + def message_down(self, down_x_j: Tensor, down_attr: Tensor) -> Tensor: + return self.msg_down_nn((down_x_j, down_attr)) + + def forward(self, cochain: CochainMessagePassingParams): + out_up, out_down, out_boundaries = self.propagate(cochain.up_index, cochain.down_index, + cochain.boundary_index, x=cochain.x, + up_attr=cochain.kwargs['up_attr'], + down_attr=cochain.kwargs['down_attr'], + boundary_attr=cochain.kwargs['boundary_attr']) + + # As in GIN, we can learn an injective update function for each multi-set + out_up += (1 + self.eps1) * cochain.x + out_down += (1 + self.eps2) * cochain.x + out_boundaries += (1 + self.eps3) * cochain.x + out_up = self.update_up_nn(out_up) + out_down = self.update_down_nn(out_down) + out_boundaries = self.update_boundaries_nn(out_boundaries) + + # We need to combine the three such that the output is injective + # Because the cross product of countable spaces is countable, such a function exists. + # And we can learn it with another MLP.
+ return self.combine_nn(torch.cat([out_up, out_down, out_boundaries], dim=-1)) + class Catter(torch.nn.Module): def __init__(self): @@ -295,6 +341,91 @@ def forward(self, *cochain_params: CochainMessagePassingParams, start_to_process out.append(self.mp_levels[dim].forward(cochain_params[dim])) return out +class CINppConv(SparseCINConv): + """A CIN++ convolution over a complex: stacks one CINppCochainConv per + dimension, extending SparseCINConv with a lower-adjacency message channel. + """ + def __init__(self, up_msg_size: int, down_msg_size: int, boundary_msg_size: Optional[int], + passed_msg_up_nn: Optional[Callable], passed_msg_down_nn: Optional[Callable], + passed_msg_boundaries_nn: Optional[Callable], + passed_update_up_nn: Optional[Callable], + passed_update_down_nn: Optional[Callable], + passed_update_boundaries_nn: Optional[Callable], + eps: float = 0., train_eps: bool = False, max_dim: int = 2, + graph_norm=BN, use_coboundaries=False, **kwargs): + super(CINppConv, self).__init__(up_msg_size, down_msg_size, boundary_msg_size, + passed_msg_up_nn, passed_msg_boundaries_nn, + passed_update_up_nn, passed_update_boundaries_nn, + eps, train_eps, max_dim, graph_norm, use_coboundaries, **kwargs) + self.max_dim = max_dim + self.mp_levels = torch.nn.ModuleList() + for dim in range(max_dim+1): + msg_up_nn = passed_msg_up_nn + if msg_up_nn is None: + if use_coboundaries: + msg_up_nn = Sequential( + Catter(), + Linear(kwargs['layer_dim'] * 2, kwargs['layer_dim']), + kwargs['act_module']()) + else: + msg_up_nn = lambda xs: xs[0] + + msg_down_nn = passed_msg_down_nn + if msg_down_nn is None: + if use_coboundaries: + msg_down_nn = Sequential( + Catter(), + Linear(kwargs['layer_dim'] * 2, kwargs['layer_dim']), + kwargs['act_module']()) + else: + msg_down_nn = lambda xs: xs[0] + + msg_boundaries_nn = passed_msg_boundaries_nn + if msg_boundaries_nn is None: + msg_boundaries_nn = lambda x: x + + update_up_nn = passed_update_up_nn + if update_up_nn is None: + update_up_nn = Sequential( + Linear(kwargs['layer_dim'], kwargs['hidden']), + graph_norm(kwargs['hidden']), + kwargs['act_module'](), + Linear(kwargs['hidden'], kwargs['hidden']), + graph_norm(kwargs['hidden']), + kwargs['act_module']() + ) + + update_down_nn = passed_update_down_nn + if update_down_nn is None: + update_down_nn = Sequential( + Linear(kwargs['layer_dim'], kwargs['hidden']), + graph_norm(kwargs['hidden']), + kwargs['act_module'](), + Linear(kwargs['hidden'], kwargs['hidden']), + graph_norm(kwargs['hidden']), + kwargs['act_module']() + ) + + update_boundaries_nn = passed_update_boundaries_nn + if update_boundaries_nn is None: + update_boundaries_nn = Sequential( + Linear(kwargs['layer_dim'], kwargs['hidden']), + graph_norm(kwargs['hidden']), + kwargs['act_module'](), + Linear(kwargs['hidden'], kwargs['hidden']), + graph_norm(kwargs['hidden']), + kwargs['act_module']() + ) + combine_nn = Sequential( + Linear(kwargs['hidden']*3, kwargs['hidden']), + graph_norm(kwargs['hidden']), + kwargs['act_module']()) + + mp = CINppCochainConv(dim, up_msg_size, down_msg_size, boundary_msg_size=boundary_msg_size, + msg_up_nn=msg_up_nn, msg_down_nn=msg_down_nn, msg_boundaries_nn=msg_boundaries_nn, update_up_nn=update_up_nn, + update_down_nn=update_down_nn, update_boundaries_nn=update_boundaries_nn, combine_nn=combine_nn, eps=eps, + train_eps=train_eps) + self.mp_levels.append(mp) + class OrientedConv(CochainMessagePassing): def __init__(self, dim: int, up_msg_size: int, down_msg_size: int, diff --git a/mp/models.py b/mp/models.py index 881d0f36..6dc0a489 100644 --- a/mp/models.py +++ b/mp/models.py @@ -4,7 +4,7 @@ from torch.nn import Linear, Sequential, BatchNorm1d as BN from
torch_geometric.nn import JumpingKnowledge from mp.layers import ( - CINConv, EdgeCINConv, SparseCINConv, DummyCellularMessagePassing, OrientedConv) + CINConv, EdgeCINConv, SparseCINConv, CINppConv, DummyCellularMessagePassing, OrientedConv) from mp.nn import get_nonlinearity, get_pooling_fn, pool_complex, get_graph_norm from data.complex import ComplexBatch, CochainBatch @@ -256,7 +256,33 @@ def forward(self, data: ComplexBatch, include_partial=False): def __repr__(self): return self.__class__.__name__ +class CINpp(SparseCIN): + """The CIN++ model: a SparseCIN whose layers also exchange messages + along lower adjacencies. + """ + def __init__(self, num_input_features, num_classes, num_layers, hidden, + dropout_rate: float = 0.5, max_dim: int = 2, jump_mode=None, + nonlinearity='relu', readout='sum', train_eps=False, + final_hidden_multiplier: int = 2, use_coboundaries=False, + readout_dims=(0, 1, 2), final_readout='sum', + apply_dropout_before='lin2', graph_norm='bn'): + super(CINpp, self).__init__(num_input_features, num_classes, num_layers, hidden, + dropout_rate, max_dim, jump_mode, nonlinearity, + readout, train_eps, final_hidden_multiplier, + use_coboundaries, readout_dims, final_readout, + apply_dropout_before, graph_norm) + self.convs = torch.nn.ModuleList() + act_module = get_nonlinearity(nonlinearity, return_module=True) + for i in range(num_layers): + layer_dim = num_input_features if i == 0 else hidden + self.convs.append( + CINppConv(up_msg_size=layer_dim, down_msg_size=layer_dim, + boundary_msg_size=layer_dim, passed_msg_boundaries_nn=None, passed_msg_up_nn=None, + passed_msg_down_nn=None, passed_update_up_nn=None, passed_update_down_nn=None, + passed_update_boundaries_nn=None, train_eps=train_eps, max_dim=self.max_dim, + hidden=hidden, act_module=act_module, layer_dim=layer_dim, + graph_norm=self.graph_norm, use_coboundaries=use_coboundaries)) + class EdgeCIN0(torch.nn.Module): """ A variant of CIN0 operating up to edge level. It may optionally ignore two_cell features.
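For reviewers, a minimal usage sketch (not part of the diff) of how the new `CINpp` model plugs into the existing loading pipeline. It mirrors the `cin++` branch added to `exp/run_exp.py`; the dataset name and hyperparameter values below are illustrative assumptions, not tuned settings.

```python
# Hypothetical sketch: instantiating CINpp on a cellular dataset, following
# the wiring in exp/run_exp.py. Values (dataset, layers, hidden) are examples.
from data.data_loading import load_dataset
from mp.models import CINpp

# Lower adjacencies must be materialised at preprocessing time for CIN++,
# hence include_down_adj=True is passed through to the dataset constructors.
dataset = load_dataset('NCI109', max_dim=2, fold=0, init_method='mean',
                       max_ring_size=6, include_down_adj=True)

model = CINpp(dataset.num_features_in_dim(0),  # num_input_features
              dataset.num_classes,             # num_classes
              4,                               # num_layers
              64,                              # hidden
              use_coboundaries=True,           # needed for up/down attr messages
              graph_norm='bn')
```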
diff --git a/mp/molec_models.py b/mp/molec_models.py index 9b2a039e..6ffd3240 100644 --- a/mp/molec_models.py +++ b/mp/molec_models.py @@ -3,7 +3,7 @@ from torch.nn import Linear, Embedding, Sequential, BatchNorm1d as BN from torch_geometric.nn import JumpingKnowledge, GINEConv -from mp.layers import InitReduceConv, EmbedVEWithReduce, OGBEmbedVEWithReduce, SparseCINConv +from mp.layers import InitReduceConv, EmbedVEWithReduce, OGBEmbedVEWithReduce, SparseCINConv, CINppConv from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder from data.complex import ComplexBatch from mp.nn import pool_complex, get_pooling_fn, get_nonlinearity, get_graph_norm @@ -164,6 +164,40 @@ def __repr__(self): return self.__class__.__name__ + +class EmbedCINpp(EmbedSparseCIN): + """ + Inherits from EmbedSparseCIN and adds messages from lower-adjacent cells + """ + + def __init__(self, atom_types, bond_types, out_size, num_layers, hidden, + dropout_rate: float = 0.5, max_dim: int = 2, jump_mode=None, + nonlinearity='relu', readout='sum', train_eps=False, + final_hidden_multiplier: int = 2, readout_dims=(0, 1, 2), + final_readout='sum', apply_dropout_before='lin2', init_reduce='sum', + embed_edge=False, embed_dim=None, use_coboundaries=False, graph_norm='bn'): + super(EmbedCINpp, self).__init__(atom_types, bond_types, out_size, num_layers, + hidden, dropout_rate, max_dim, jump_mode, + nonlinearity, readout, train_eps, + final_hidden_multiplier, readout_dims, + final_readout, apply_dropout_before, + init_reduce, embed_edge, embed_dim, + use_coboundaries, graph_norm) + self.convs = torch.nn.ModuleList() # Reset convs to use CINppConv instead of SparseCINConv + act_module = get_nonlinearity(nonlinearity, return_module=True) + + if embed_dim is None: + embed_dim = hidden + + for i in range(num_layers): + layer_dim = embed_dim if i == 0 else hidden + self.convs.append( + CINppConv(up_msg_size=layer_dim, down_msg_size=layer_dim, + boundary_msg_size=layer_dim, passed_msg_boundaries_nn=None, + passed_msg_up_nn=None, passed_msg_down_nn=None, passed_update_up_nn=None, + passed_update_down_nn=None, passed_update_boundaries_nn=None, train_eps=train_eps, + max_dim=self.max_dim, hidden=hidden, act_module=act_module, layer_dim=layer_dim, + graph_norm=self.graph_norm, use_coboundaries=use_coboundaries)) + class OGBEmbedSparseCIN(torch.nn.Module): """ A cellular version of GIN with some tailoring to nimbly work on molecules from the ogbg-mol* dataset.
@@ -318,6 +352,36 @@ def forward(self, data: ComplexBatch, include_partial=False): def __repr__(self): return self.__class__.__name__ +class OGBEmbedCINpp(OGBEmbedSparseCIN): + """ + Inherits from OGBEmbedSparseCIN and adds messages from lower-adjacent cells + """ + def __init__(self, out_size, num_layers, hidden, dropout_rate: float = 0.5, + indropout_rate: float = 0, max_dim: int = 2, jump_mode=None, + nonlinearity='relu', readout='sum', train_eps=False, + final_hidden_multiplier: int = 2, readout_dims=(0, 1, 2), + final_readout='sum', apply_dropout_before='lin2', init_reduce='sum', + embed_edge=False, embed_dim=None, use_coboundaries=False, graph_norm='bn'): + super().__init__(out_size, num_layers, hidden, dropout_rate, indropout_rate, + max_dim, jump_mode, nonlinearity, readout, train_eps, + final_hidden_multiplier, readout_dims, final_readout, + apply_dropout_before, init_reduce, embed_edge, embed_dim, + use_coboundaries, graph_norm) + self.convs = torch.nn.ModuleList() # Reset convs to use CINppConv instead of SparseCINConv + act_module = get_nonlinearity(nonlinearity, return_module=True) + + if embed_dim is None: + embed_dim = hidden + + for i in range(num_layers): + layer_dim = embed_dim if i == 0 else hidden + self.convs.append( + CINppConv(up_msg_size=layer_dim, down_msg_size=layer_dim, + boundary_msg_size=layer_dim, passed_msg_boundaries_nn=None, + passed_msg_up_nn=None, passed_msg_down_nn=None, passed_update_up_nn=None, + passed_update_down_nn=None, passed_update_boundaries_nn=None, train_eps=train_eps, + max_dim=self.max_dim, hidden=hidden, act_module=act_module, layer_dim=layer_dim, + graph_norm=self.graph_norm, use_coboundaries=use_coboundaries)) class EmbedSparseCINNoRings(torch.nn.Module): """
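As a closing note, a minimal sanity check (not part of the diff) for the new `'ap'` metric wired into `exp/train_utils.py`. It calls the evaluator's dispatched `eval_fn` directly, using the same `input_dict` convention as `_accuracy` and `_mae`; with 2-D arrays, sklearn's `average_precision_score` macro-averages over tasks, which matches the multi-task binary labels of PEPTIDES-F.

```python
# Hypothetical sanity check for the new 'ap' evaluation path.
# input_dict follows the y_true/y_pred convention of the other metrics.
import numpy as np
from exp.train_utils import Evaluator

evaluator = Evaluator('ap')  # dispatches to sklearn.metrics.average_precision_score

input_dict = {
    'y_true': np.array([[1, 0], [0, 1], [1, 1]]),              # per-task binary labels
    'y_pred': np.array([[0.9, 0.2], [0.2, 0.8], [0.7, 0.6]]),  # prediction scores
}
print(evaluator.eval_fn(input_dict))  # macro-averaged AP across the two tasks
```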