diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 76c89c9..4d8a220 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,13 +1,24 @@ import unittest import tempfile +import networkx as nx + from rnaglib.data_loading import RNADataset from rnaglib.data_loading import FeaturesComputer from rnaglib.representations import GraphRepresentation from rnaglib.transforms import RNAFMTransform +from rnaglib.transforms import RfamTransform +from rnaglib.transforms import Compose class TransformsTest(unittest.TestCase): + def check_ndata(self, g, attribute: str): + _, ndata = next(iter(g.nodes(data=True))) + assert attribute in ndata + + def check_gdata(self, g, attribute: str): + assert attribute in g.graph + @classmethod def setUpClass(self): self.dataset = RNADataset(debug=True) @@ -15,8 +26,18 @@ def setUpClass(self): def test_RNAFMTransform(self): tr = RNAFMTransform() tr(self.dataset[0]) + tr(self.dataset) pass + def test_simple_compose(self): + g = self.dataset[0] + tr_1 = RNAFMTransform() + tr_2 = RfamTransform() + t = Compose([tr_1, tr_2]) + t(self.dataset[0]) + self.check_gdata(g['rna'], 'rfam') + self.check_ndata(g['rna'], 'rnafm') + def test_pre_transform(self): """ Add rnafm embeddings during dataset construction from database, then look up the stored attribute at getitem time.