From f8d37b1ceb81a7e96bac35149c63a277febc1ddc Mon Sep 17 00:00:00 2001 From: cgoliver Date: Thu, 19 Sep 2024 18:07:29 -0400 Subject: [PATCH] docs --- docs/source/command_line.rst | 29 ++ docs/source/index.rst | 11 +- docs/source/tuto_custom_task.rst | 251 +++++++++++++++++- docs/source/tuto_transforms.rst | 151 +++++++++++ .../tasks/RBP_Node/protein_binding_site.py | 6 +- .../tasks/RNA_CM/chemical_modification.py | 8 +- src/rnaglib/tasks/RNA_Family/rfam.py | 5 +- src/rnaglib/tasks/task.py | 5 +- src/rnaglib/transforms/filter/filters.py | 81 ++++-- 9 files changed, 497 insertions(+), 50 deletions(-) create mode 100644 docs/source/command_line.rst create mode 100644 docs/source/tuto_transforms.rst diff --git a/docs/source/command_line.rst b/docs/source/command_line.rst new file mode 100644 index 0000000..fa2347d --- /dev/null +++ b/docs/source/command_line.rst @@ -0,0 +1,29 @@ +Command Line Utilities +------------------------- + + +We provide several command line utilities which you can use to set up +the rnaglib environment. + + +Database building +~~~~~~~~~~~~~~~~~~~~~~~~ + +To build or update a local database of RNA structures along with their annotations, +you can use the ``rnaglib_prepare_data`` command line utility. + + +:: + + $ rnaglib_prepare_data -s structures/ --tag first_build -o builds/ -d + +Database Indexing +~~~~~~~~~~~~~~~~~~~ + +Indexing a database collects information about annotations present in a +database to enable rapid access of particular RNAs given some desired +properties.:: + + $ rnaglib_index + + diff --git a/docs/source/index.rst b/docs/source/index.rst index 6957e9f..f52acc4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -18,17 +18,10 @@ :caption: Tutorials :hidden: - Machine Learning on Benchmark Datasets What is an RNA 2.5D graph? A tour of RNA 2.5D graphs - -.. 
toctree:: - :maxdepth: 2 - :caption: Advanced Tutorials - :hidden: - - Creating custom tasks and splits - Add your own annotations and features + ML Tasks API + RNA Transforms How is the data built? diff --git a/docs/source/tuto_custom_task.rst b/docs/source/tuto_custom_task.rst index 08fd6a6..49420f5 100644 --- a/docs/source/tuto_custom_task.rst +++ b/docs/source/tuto_custom_task.rst @@ -1,17 +1,59 @@ -Anatomy of a Task +Using Tasks API +-------------------------- + +An ``rnaglib.Task`` object packages everything you need to train a model for a particular biological problem. + +The key components of a ``Task`` you will use are: + +* A dedicated dataset for the task +* Train/validation/test dataloaders +* Model evaluator + +When you instantiate the task, the task either calls the ``process()`` and ``split()`` methods to compute the necessary data or you have already run this before and the result was stored in the ``root`` directory and loading should be instantaneous. + +Once loading is complete you only need to select a tensor representation (e.g. graph, voxel, point cloud) and encode the underlying dataset with it and then iterate through the train loader. Note that whenever you update the task's dataset you should call ``set_loaders()`` so that changes in the dataset are reflected in the data served by the loaders:: + + + + from rnaglib.tasks import ChemicalModification + from rnaglib.transforms import GraphRepresentation + + ta = ChemicalModification(root='cm') + ta.dataset.add_representation(GraphRepresentation(framework='pyg')) + ta.set_loaders() + + for batch in ta.train_loader: + pred = ta.dummy_model(batch['graph']) + ... + + + metrics = ta.evaluate(ta.dummy_model) + + +Once you have completed training you can pass your model to the task's ``evaluate()`` method which will return a dictionary of metrics and performance values. + +.. note:: + + Each task provides a ``dummy_model`` variable which you can use for testing out the task. 
It simply returns a random prediction of the appropriate shape. + + + +Building Custom Tasks ------------------------------------- -If you would like to propose a new prediction task for the machine learning community. We provide the customizable ``Task`` class. +If you would like to propose a new prediction task for the machine learning community, you just have to implement a few methods in a subclass of the ``Task`` class. An instance of the ``Task`` class packages the following attributes: - ``dataset``: full collection of RNAs to use in the task. - ``splitter``: method for partitioning the dataset into train, validation, and test subsets. -- ``features_computer``: method for setting and encoding input and target variables. +- ``target_vars``: method for setting and encoding input and target variables. - ``evaluate``: method which accepts a model and returns performance metrics. +Once the task processing is complete, all task data is dumped into ``root`` which is a path passed to the task init method. -Here is a template for a custom task:: + +Here is a minimal template for a custom task:: from rnaglib.tasks import Task from rnaglib.data_loading import RNADataset @@ -19,15 +61,208 @@ Here is a template for a custom task:: class MyTask(Task): - def build_dataset(self) -> RNADataset: + def __init__(self, root): + super().__init__(root) + + def process(self) -> RNADataset: # build the task's dataset + # ... pass + @property def default_splitter() -> Splitter: # return a splitter object to build train/val/test splits + # ... pass - - def features_computer() -> FeaturesComputer: + + def get_task_vars() -> FeaturesComputer: # computes the task's default input and target variables + # managed by creating a FeaturesComputer object + # ... 
+ pass + + +In this tutorial we will walk through the steps to create a task with the aim of predicting for each residue, whether or not it will be chemically modified, and in a more advanced example we will build the task of predicting the Rfam classification of an RNA. + +Types of tasks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Tasks can operate at the residue, edge, and whole RNA level. +Boilerplate for evaluation and loading would be affected depending on the choice of level. +For that reason we create sub-classes of the ``Task`` class which you can use to avoid re-coding such things. + + +Since chemical modifications are applied to residues, let's build a residue-level binary classification task.:: + + from rnaglib.tasks import ResidueClassificationTask + + class ChemicalModification(ResidueClassificationTask): + .... + + + + +1. Create the task's dataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Each task needs to define which RNAs to use. Typically this involves filtering a whole dataset of available RNAs by certain attributes to retain only the ones that contain certain annotations or pass certain criteria (e.g. size, origin, resolution, etc.). + +You are free to do this in any way you like as long as after ``Task.process()`` is called, a list of ``.json`` graphs storing the RNA annotations the task needs is dumped into ``{root}/dataset``. + +To make things easier you can take advantage of the ``rnaglib.transforms`` library which provides functionality for manipulating datasets of RNAs. + +Let's define a ``Task.process()`` method which builds a dataset with a single criterion: + +* Only keep RNAs that contain at least one chemically modified residue + +The ``Transforms`` library provides a filter which checks that an RNA's residues are of a desired value. 
:: + + from rnaglib.data_loading import RNADataset + from rnaglib.tasks import ResidueClassificationTask + from rnaglib.transforms import ResidueAttributeFilter + from rnaglib.transforms import PDBIDNameTransform + + class ChemicalModification(ResidueClassificationTask): + def process(self) -> RNADataset: + # grab a full set of available RNAs + rnas = RNADataset() + + filter = ResidueAttributeFilter(attribute='is_modified', + value_checker=lambda val: val == True + ) + + rnas = filter(rnas) + + rnas = PDBIDNameTransform()(rnas) + dataset = RNADataset(rnas=[r["rna"] for r in rnas]) + return dataset + + pass + + +Applying the filter gives us a new list containing only the RNAs that passed the filter. The last thing we need to do is assign a ``name`` value to each RNA so that they can be properly managed by the ``RNADataset``. We assign the PDBID as the name of each item in our dataset using the ``PDBIDNameTransform``. + +Now we just create a new ``RNADataset`` object using the reduced list. The dataset object requires a list and not a generator so we just unroll before passing it. + +That's it, now you just return the new ``RNADataset`` object. + +2. Set the task's variables +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Apart from the RNAs themselves, the task needs to know which variables are relevant. In particular we need to set the prediction target. Additionally we can set some default input features, which are always provided. The user can always add more input features if he/she desires by manipulating ``task.dataset.features_computer`` but at the minimum we need to define target variables.:: + + from rnaglib.data_loading import RNADataset + from rnaglib.tasks import ResidueClassificationTask + from rnaglib.transforms import ResidueAttributeFilter + from rnaglib.transforms import PDBIDNameTransform + from rnaglib.transforms import FeaturesComputer + + class ChemicalModification(ResidueClassificationTask): + def process(self) -> RNADataset: + ... 
+ pass + + def get_task_vars(self) -> FeaturesComputer: + return FeaturesComputer(nt_features=['nt_code'], nt_targets=['is_modified']) + + +Here we simply have a nucleotide level target so we pass the ``'is_modified'`` attribute to the ``FeaturesComputer`` object. This will take care of selecting the residue when encoding the RNA into tensor form. In addition we provide the nucleotide identity (``'nt_code'``) as a default input feature. + + +3. Train/val/test splits +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The last necessary step is to define the train, validation and test subsets of the whole dataset. Once these are set, the task's boilerplate will take care of generating the appropriate loaders. + +To set the splits, you implement the ``default_splitter()`` method which returns a ``Splitter`` object. A ``Splitter`` object is simply a callable which accepts a dataset and returns three lists of indices representing the train, validation and test subsets. + +You can select from the library of implemented splitters or implement your own. + +For this example, we will split the RNAs by structural similarity using RNA-align.:: + + from rnaglib.data_loading import RNADataset + from rnaglib.tasks import ResidueClassificationTask + + from rnaglib.transforms import ResidueAttributeFilter + from rnaglib.transforms import PDBIDNameTransform + from rnaglib.transforms import FeaturesComputer + + from rnaglib.splitters import Splitter, RNAalignSplitter + + class ChemicalModification(ResidueClassificationTask): + def process(self) -> RNADataset: + ... + pass + + def get_task_vars(self) -> FeaturesComputer: + return FeaturesComputer(nt_features=['nt_code'], nt_targets=['is_modified']) + + @property + def default_splitter(self) -> Splitter: + return RNAalignSplitter(similarity_threshold=0.6) + + +Now our splits will guarantee a maximum structural similarity of 0.6 between them. + +Check out the Splitter class for a quick guide on how to create your own splitters. 
+ +Note that this is only setting the default method to use for splitting the dataset. If a user wants to try a different splitter it can be passed to the task's init. + +That's it! Your task is now fully defined and can be used in model training and evaluation. + +Here is the full task implementation:: + + + from rnaglib.data_loading import RNADataset + from rnaglib.tasks import ResidueClassificationTask + from rnaglib.transforms import FeaturesComputer + from rnaglib.transforms import ResidueAttributeFilter + from rnaglib.transforms import PDBIDNameTransform + from rnaglib.splitters import Splitter, RNAalignSplitter + + + class ChemicalModification(ResidueClassificationTask): + """Residue-level binary classification task to predict whether or not a given + residue is chemically modified. + """ + + target_var = "is_modified" + + def __init__(self, root, splitter=None, **kwargs): + super().__init__(root=root, splitter=splitter, **kwargs) + + def get_task_vars(self): + return FeaturesComputer(nt_targets=self.target_var) + + def process(self): + rnas = ResidueAttributeFilter( + attribute=self.target_var, value_checker=lambda val: val == True + )(RNADataset(debug=self.debug)) + rnas = PDBIDNameTransform()(rnas) + dataset = RNADataset(rnas=[r["rna"] for r in rnas]) + return dataset + + def default_splitter(self) -> Splitter: + return RNAalignSplitter(similarity_threshold=0.6) + + +Customize Splitting +------------------------ + +We provide some pre-defined splitters for sequence and structure-based splitting. If you have other criteria for splitting you can subclass the ``Splitter`` class. All you have to do is implement the ``__call__()`` method which takes a dataset and returns three lists of indices:: + + class Splitter: + def __init__(self, split_train=0.7, split_valid=0.15, split_test=0.15): + assert sum([split_train, split_valid, split_test]) == 1, "Splits don't sum to 1." 
+ self.split_train = split_train + self.split_valid = split_valid + self.split_test = split_test pass - + + def __call__(self, dataset): + return None, None, None + + +The ``__call__(self, dataset)`` method returns three lists of indices from the given ``dataset`` object. + +The splitter can be initiated with the desired proportions of the dataset for each subset. diff --git a/docs/source/tuto_transforms.rst b/docs/source/tuto_transforms.rst new file mode 100644 index 0000000..c0e0c03 --- /dev/null +++ b/docs/source/tuto_transforms.rst @@ -0,0 +1,151 @@ +Manipulating RNAs: Transforms API +------------------------------------------ + +The ``Transforms`` API handles any operations that modify RNA dictionaries. + +Reminder, an RNA dictionary is the item provided by an ``RNADataset()[i]`` and looks like:: + + >>> from rnaglib.data_loading import RNADataset + >>> dataset = RNADataset(debug=True) + >>> rna = dataset[3] + {'rna': , ..., } + + +Transforms are ``Callable`` objects which operate on individual RNAs or collections of RNAs. Let's see by importing a transform that does nothing.:: + + >>> from rnaglib.transforms import Transform + >>> t = Transform() + >>> new_rna = t(rna) + >>> new_rnas = t(dataset) + +To customize the behaviour of the transform you can usually pass arguments to the object constructor. Looking inside a transform all you have is:: + + class Transform: + def __init__(self): + # any setup for the transform + + def forward(self, data: dict): + # apply operation to the RNA dictionary + pass + + + +.. note:: + Transforms can usually be applied in parallel for faster computing by passing `parallel=True` to the constructor. + + +Transforms come in several flavors depending on the kind of manipulation they apply to the provided data: + +* **Annotation**: adds or removes annotations from the RNA (e.g. query a database and store results in the RNA) +* **Filter**: accept or reject certain RNAs from a collection based on some criteria (e.g. 
remove RNAs that are too large) +* **Partition**: generate a collection of substructure from a whole RNA (e.g. break up an RNA into individual chains) +* **Featurize**: convert RNA annotations to tensors for learning. +* **Represent**: compute tensor-based representations of RNAs (e.g. convert to voxel grid) + + +Annotation Transforms: add/remove data from RNAs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Annotation transforms update the attributes of an RNA, usually by adding a new key/value pair to node/edge/graph-level annotations. This is useful when the annotations provided by default are not enough. + +For example, if you want to store the Rfam class of an RNA you can use the ``RfamTransform``:: + + >>> from rnaglib.transforms import RfamTransform + >>> from rnaglib.data_loading import RNADataset + >>> dset = RNADataset(debug=True) + >>> t = RfamTransform() + >>> t(dset) + >>> dset[0]['rna'].graph['rfam'] + 'RF0005' + +For annotation transforms, the ``forward()`` method modifies the given RNA dictionary, optionally returns it if you don't want to work in-place. + +Filter Transforms: narrow down datasets +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Filters reduce a collection of RNAs based on a given test criterion. For example, if you want to only keep RNAs that have a certain maximum size.:: + + >>> from rnaglib.transforms import SizeFilter + >>> t = SizeFilter(max_size=50) + >>> rnas = t(dset) + +The new ``rnas`` list will contain only the RNAs that have fewer than 50 residues. + +To implement a filtering transform, the ``forward()`` method accepts an RNA dictionary and returns True or False. + + +Partition Transforms: focus on substructures +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to only keep certain substructures of an RNA. 
For example by extracting only binding sites, or splitting into individual chains, use the partition transforms family.:: + + >>> from rnaglib.transforms import ChainSplit + >>> from rnaglib.data_loading import RNADataset + >>> t = ChainSplit() + >>> dset = RNADataset(debug=True) + >>> t(dset) + +Now instead of the dataset containing a list of RNAs that can each have multiple chains, the new list will contain possibly more entries but each entry only consists of a single chain. + +To implement a partition transform, the ``forward()`` method defines a **generator** which accepts a single RNA dictionary and yields substructures from the given RNA. + +Represent Transform: geometric representations of RNAs for learning +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For deep learning, the raw RNA data has to be encoded into a mathematical structure known as a **representation** which holds the geometric information (e.g. base pairing graph, voxel grid, point cloud). :: + + >>> from rnaglib.transforms import GraphRepresentation + >>> from rnaglib.transforms import PointCloudRepresentation + >>> t1 = GraphRepresentation() + >>> t2 = PointCloudRepresentation() + >>> dset = RNADataset(debug=True, representations=[t1, t2]) + >>> dset[0] + {'rna': ..., 'graph': ..., 'point_cloud': ...} + + +You can apply the representation directly to an RNA as with the other transforms. However most of the time you will be passing it to a dataset so that when you load the RNAs they are converted to the necessary representation. + +Check the documentation for arguments to representations. You will typically pass an ID of the deep learning framework you need for the representation (e.g. ``GraphRepresentation(framework='pyg')`` to use pytorch geometric). + +Featurize: encode attributes for ML models +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Finally, a special transform is used to convert raw RNA attributes which have no constraints on their format (e.g. 
they can be strings representing the Rfam family or nucleotide value) to tensors. The feature encoder transforms can do this both for input features provided to the model at learning time, or as target features which are the variable the model is trying to predict.:: + + >>> from rnaglib.transforms import FeaturesComputer + >>> from rnaglib.data_loading import RNADataset + >>> ft = FeaturesComputer(nt_features=['nt_code'], nt_targets=['is_modified']) + >>> dataset = RNADataset(debug=True) + >>> features_dict = ft(dataset[0]) + {'nt_features': Tensor(...), 'nt_targets': Tensor(...)} + +The above features computer, when called on an RNA graph returns a dictionary of tensors representing the nucleotide ID and chemical modification status. + +Most likely you won't use this directly and instead pass the features computer to the ``RNADataset`` object so that the features are served by the loader.:: + + >>> RNADataset(features_computer=ft) + + +Additionally, you can load a task and choose which variables you want to feed your model:: + + >>> from rnaglib.tasks import ChemicalModification + >>> ta = ChemicalModification() + >>> ta.dataset.features_computer.add_feature('alpha') + +The features computer has a method to add and remove features so you can go beyond the default features provided by the task. + +Combining Transforms +~~~~~~~~~~~~~~~~~~~~~~~ + +Transforms of the same kind can be stitched together to avoid repeated iterations on the same list of RNAs using the ``Compose`` transform.:: + + >>> from rnaglib.transforms import FilterTransform + >>> from rnaglib.transforms import RfamTransform + >>> from rnaglib.transforms import RNAFMTransform + >>> from rnaglib.data_loading import RNADataset + >>> dataset = RNADataset(debug=True) + >>> t = [RfamTransform(), RNAFMTransform()] + >>> t(dataset) + + +Each type of transform has its own compose object to deal with the slightly different behaviour. 
If you are composing filters use the ``ComposeFilters`` or composing partitions use the ``ComposePartitions``. diff --git a/src/rnaglib/tasks/RBP_Node/protein_binding_site.py b/src/rnaglib/tasks/RBP_Node/protein_binding_site.py index f7df6ac..2d90954 100644 --- a/src/rnaglib/tasks/RBP_Node/protein_binding_site.py +++ b/src/rnaglib/tasks/RBP_Node/protein_binding_site.py @@ -18,7 +18,7 @@ class ProteinBindingSiteDetection(ResidueClassificationTask): def __init__(self, root, splitter=None, **kwargs): super().__init__(root=root, splitter=splitter, **kwargs) - def get_features_computer(self): + def get_task_vars(self): return FeaturesComputer(nt_targets=self.target_var) def process(self): @@ -27,7 +27,9 @@ def process(self): # build the filters ribo_filter = RibosomalFilter() - non_bind_filter = ResidueAttributeFilter(attribute=self.target_var) + non_bind_filter = ResidueAttributeFilter( + attribute=self.target_var, value_checker=lambda val: val is not None + ) filters = ComposeFilters([ribo_filter, non_bind_filter]) # assign a name to each remaining RNA diff --git a/src/rnaglib/tasks/RNA_CM/chemical_modification.py b/src/rnaglib/tasks/RNA_CM/chemical_modification.py index 01678b4..f3ddf86 100644 --- a/src/rnaglib/tasks/RNA_CM/chemical_modification.py +++ b/src/rnaglib/tasks/RNA_CM/chemical_modification.py @@ -15,13 +15,13 @@ class ChemicalModification(ResidueClassificationTask): def __init__(self, root, splitter=None, **kwargs): super().__init__(root=root, splitter=splitter, **kwargs) - def get_features_computer(self): + def get_task_vars(self): return FeaturesComputer(nt_targets=self.target_var) def process(self): - rnas = ResidueAttributeFilter(attribute=self.target_var)( - RNADataset(debug=self.debug) - ) + rnas = ResidueAttributeFilter( + attribute=self.target_var, value_checker=lambda val: val == True + )(RNADataset(debug=self.debug)) rnas = PDBIDNameTransform()(rnas) dataset = RNADataset(rnas=[r["rna"] for r in rnas]) return dataset diff --git 
a/src/rnaglib/tasks/RNA_Family/rfam.py b/src/rnaglib/tasks/RNA_Family/rfam.py index 7d60305..b721f66 100644 --- a/src/rnaglib/tasks/RNA_Family/rfam.py +++ b/src/rnaglib/tasks/RNA_Family/rfam.py @@ -30,7 +30,7 @@ def __init__(self, root, max_size: int = 200, splitter=None, **kwargs): super().__init__(root=root, splitter=splitter, **kwargs) pass - def features_computer(self): + def get_task_vars(self): return FeaturesComputer( rna_targets=["rfam"], custom_encoders={"rfam": OneHotEncoder(self.metadata["label_mapping"])}, @@ -51,9 +51,6 @@ def process(self): rnas = ChainSplitTransform()(rnas) rnas = ChainNameTransform()(rnas) - ft = FeaturesComputer( - rna_targets=[tr_rfam.name], custom_encoders={tr_rfam.name: tr_rfam.encoder} - ) new_dataset = RNADataset(rnas=list((r["rna"] for r in rnas))) return new_dataset diff --git a/src/rnaglib/tasks/task.py b/src/rnaglib/tasks/task.py index f8d5a9c..9b11d57 100644 --- a/src/rnaglib/tasks/task.py +++ b/src/rnaglib/tasks/task.py @@ -57,7 +57,7 @@ def __init__( self.metadata = metadata self.dataset = dataset - self.dataset.features_computer = self.get_features_computer() + self.dataset.features_computer = self.get_task_vars() self.train_ind = train_ind self.val_ind = val_ind @@ -78,7 +78,8 @@ def init_metadata(self) -> dict: """Optionally adds some key/value pairs to self.metadata.""" return {} - def get_features_computer(self) -> FeaturesComputer: + @property + def get_task_vars(self) -> FeaturesComputer: """Define a FeaturesComputer object to set which input and output variables will be used in the task.""" return FeaturesComputer() diff --git a/src/rnaglib/transforms/filter/filters.py b/src/rnaglib/transforms/filter/filters.py index cf7b3b3..44ca69d 100644 --- a/src/rnaglib/transforms/filter/filters.py +++ b/src/rnaglib/transforms/filter/filters.py @@ -1,4 +1,4 @@ -from typing import Iterator +from typing import Iterator, Any, Callable import requests import networkx as nx @@ -9,29 +9,33 @@ desired conditione. 
""" + class SizeFilter(FilterTransform): - """ Reject RNAs that are not in the given size bounds. + """Reject RNAs that are not in the given size bounds. :param min_size: smallest allowed number of residues :param max_size: largest allowed number of residues. Default -1 which means no upper bound. """ - def __init__(self, min_size:int = 0, max_size: int = -1, **kwargs): + + def __init__(self, min_size: int = 0, max_size: int = -1, **kwargs): self.min_size = min_size self.max_size = max_size super().__init__(**kwargs) def forward(self, rna_dict: dict) -> bool: - n = len(rna_dict['rna'].nodes()) + n = len(rna_dict["rna"].nodes()) if self.max_size == -1: return n > self.min_size else: return n > self.min_size and n < self.max_size + class RNAAttributeFilter(FilterTransform): - """ Reject RNAs that lack a certain annotation at the whole RNA level. + """Reject RNAs that lack a certain annotation at the whole RNA level. :param attribute: which RNA-level attribute to look for. """ + def __init__(self, attribute: str, **kwargs): self.attribute = attribute super().__init__(**kwargs) @@ -39,71 +43,106 @@ def __init__(self, attribute: str, **kwargs): def forward(self, data: dict): try: - annot = data['rna'].graph[self.attribute] + annot = data["rna"].graph[self.attribute] except KeyError: return False else: if annot is None: return False return True + pass + class ResidueAttributeFilter(FilterTransform): - """ Reject RNAs that lack a certain annotation at the whole residue-level. + """Reject RNAs that lack a certain annotation at the whole residue-level. :param attribute: which node-level attribute to look for. + :param value_checker: function with accepts the value of the desired attribute and returns True/False :param min_valid: minium number of valid nodes that pass the filter for keeping the RNA. 
+ + + Example + --------- + + Keep RNAs with at least 1 chemically modified residue:: + + >>> from rnaglib.data_loading import RNADataset + >>> from rnaglib.transforms import ResidueAttributeFilter + + >>> dset = RNADataset(debug=True) + >>> t = ResidueAttributeFilter(attribute='is_modified', + value_checker: lambda val: val == True, + min_valid=1) + >>> len(dset) + >>> rnas = list(t(dset)) + >>> len(rnas) + + """ - def __init__(self, attribute: str, min_valid: int = 1, **kwargs): + def __init__( + self, + attribute: str, + value_checker: Callable = None, + min_valid: int = 1, + **kwargs, + ): self.attribute = attribute self.min_valid = min_valid + self.value_checker = value_checker super().__init__(**kwargs) pass def forward(self, data: dict): n_valid = 0 - g = data['rna'] + g = data["rna"] for node, ndata in g.nodes(data=True): try: - annot = ndata[self.attribute] + val = ndata[self.attribute] except KeyError: continue else: - if annot is None: - continue - n_valid += 1 + if self.value_checker(val): + n_valid += 1 if n_valid >= self.min_valid: return True return False class RibosomalFilter(FilterTransform): - """ Remove RNA if ribosomal """ - ribosomal_keywords = ['ribosomal', 'rRNA', '50S', '30S', '60S', '40S'] + """Remove RNA if ribosomal""" + + ribosomal_keywords = ["ribosomal", "rRNA", "50S", "30S", "60S", "40S"] + def __init__(self, **kwargs): super().__init__(**kwargs) pass + def forward(self, data: dict): - pdbid = data['rna'].graph['pdbid'][0] + pdbid = data["rna"].graph["pdbid"][0] url = f"https://data.rcsb.org/rest/v1/core/entry/{pdbid}" response = requests.get(url) data = response.json() # Check title and description - title = data.get('struct', {}).get('title', '').lower() + title = data.get("struct", {}).get("title", "").lower() if any(keyword in title for keyword in self.ribosomal_keywords): return False # Check keywords - keywords = data.get('struct_keywords', {}).get('pdbx_keywords', '').lower() + keywords = data.get("struct_keywords", 
{}).get("pdbx_keywords", "").lower() if any(keyword in keywords for keyword in self.ribosomal_keywords): return False # Check polymer descriptions (for RNA and ribosomal proteins) - for polymer in data.get('polymer_entities', []): - description = polymer.get('rcsb_polymer_entity', {}).get('pdbx_description', '').lower() + for polymer in data.get("polymer_entities", []): + description = ( + polymer.get("rcsb_polymer_entity", {}) + .get("pdbx_description", "") + .lower() + ) if any(keyword in description for keyword in self.ribosomal_keywords): return False - + return True