Skip to content

Commit

Permalink
Allow passing multiple name for registering a value in RegistryMixin (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz authored and bfineran committed Nov 16, 2023
1 parent f269c02 commit 178cbf5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
36 changes: 23 additions & 13 deletions src/sparsezoo/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import importlib
from collections import defaultdict
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, Union


__all__ = [
Expand Down Expand Up @@ -64,6 +64,11 @@ class ImageNetDataset(Dataset):
class Cifar(Dataset):
pass
# register with multiple aliases
@Dataset.register(name=["cifar-10-dataset", "cifar-100-dataset"])
class Cifar(Dataset):
pass
# load as "cifar-dataset"
cifar = Dataset.load_from_registry("cifar-dataset")
Expand All @@ -77,12 +82,13 @@ class Cifar(Dataset):
registry_requires_subclass: bool = False

@classmethod
def register(cls, name: Optional[str] = None):
def register(cls, name: Union[List[str], str, None] = None):
"""
Decorator for registering a value (ie class or function) wrapped by this
decorator to the base class (class that .register is called from)
:param name: name to register the wrapped value as, defaults to value.__name__
:param name: name or list of names to register the wrapped value as,
defaults to value.__name__
:return: register decorator
"""

Expand All @@ -93,18 +99,22 @@ def decorator(value: Any):
return decorator

@classmethod
def register_value(cls, value: Any, name: Optional[str] = None):
def register_value(cls, value: Any, name: Union[List[str], str, None] = None):
"""
Registers the given value to the class `.register_value` is called from
:param value: value to register
:param name: name to register the wrapped value as, defaults to value.__name__
:param name: name or list of names to register the wrapped value as,
defaults to value.__name__
"""
register(
parent_class=cls,
value=value,
name=name,
require_subclass=cls.registry_requires_subclass,
)
names = name if isinstance(name, list) else [name]

for name in names:
register(
parent_class=cls,
value=value,
name=name,
require_subclass=cls.registry_requires_subclass,
)

@classmethod
def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
Expand Down Expand Up @@ -148,7 +158,7 @@ def register(
):
"""
:param parent_class: class to register the name under
:param value: value to register
:param value: the value to register
:param name: name to register the wrapped value as, defaults to value.__name__
:param require_subclass: require that value is a subclass of the class this
method is called from
Expand Down Expand Up @@ -193,7 +203,7 @@ def get_from_registry(
# look up name in registry
retrieved_value = _REGISTRY[parent_class].get(name)
if retrieved_value is None:
raise ValueError(
raise KeyError(
f"Unable to find {name} registered under type {parent_class}. "
f"Registered values for {parent_class}: "
f"{registered_names(parent_class)}"
Expand Down
15 changes: 14 additions & 1 deletion tests/sparsezoo/utils/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,30 @@ class Foo(RegistryMixin):
class Foo1(Foo):
pass

assert {"Foo1"} == set(Foo.registered_names())

@Foo.register(name="name_2")
class Foo2(Foo):
pass

assert {"Foo1", "name_2"} == set(Foo.registered_names())

with pytest.raises(ValueError):
@Foo.register(name=["name_3", "name_4"])
class Foo3(Foo):
pass

assert {"Foo1", "name_2", "name_3", "name_4"} == set(Foo.registered_names())

with pytest.raises(KeyError):
Foo.get_value_from_registry("Foo2")

assert Foo.get_value_from_registry("Foo1") is Foo1
assert isinstance(Foo.load_from_registry("name_2"), Foo2)
assert (
Foo.get_value_from_registry("name_3")
is Foo3
is Foo.get_value_from_registry("name_4")
)


def test_registry_flow_multiple():
Expand Down

0 comments on commit 178cbf5

Please sign in to comment.