Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "task_name_field" #109

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 43 additions & 57 deletions mttl/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,25 @@ class MultiDefaultValue:
different defaults for the same field. This class is used to store all the defaults and resolve them when needed.
"""

def __init__(self, cls: dataclass, name: str, field_type: Type[T], default: T):
self.name = name
def __init__(self, field_type: Type[T]):
self.type = field_type
self.defaults: Dict[str, T] = {cls.__name__: default}
self.defaults: Dict[str, T] = {}

def update(self, cls, default, field_type):
def add_default(self, klass, value, field_type):
if field_type != self.type:
raise TypeError(
f"Field '{self.name}' has conflicting types: {field_type} and {self.type}."
f"Field has conflicting types: {field_type} and {self.type}."
)

# Add a new default only if it's different from the last one
last_default = list(self.defaults.values())[-1]
if default != last_default:
self.defaults[cls.__name__] = default
if value not in set(self.defaults.values()):
self.defaults[klass.__name__] = value

def resolve_default(self):
# if any of these attributes is required, we need to specify it in the config
return next(iter(self.defaults.values())) if len(self.defaults) == 1 else self
def resolve(self):
if len(self.defaults) == 1:
return list(self.defaults.values())[0]
return self

def __repr__(self):
if len(self.defaults) == 1:
return repr(self.default)
return f"MultiDefaultValue({self.defaults})"


Expand All @@ -65,23 +61,27 @@ def dataclasses_union(*dataclasses: Type[dataclass]) -> List:
)

if field_.name not in new_fields:
multi_default = MultiDefaultValue(
klass, name, field_.type, field_.default
)
new_fields[name] = (field_.type, field(default=multi_default))
# If the field does not exist, we create it
multi_default = MultiDefaultValue(field_.type)
else:
new_fields[name][1].default.update(klass, field_.default, field_.type)
multi_default = new_fields[name][1].default

multi_default.add_default(klass, field_.default, field_.type)
new_fields[name] = (multi_default.type, field(default=multi_default))

# We resolve the default if possible for each field, if we can't resolve default will be MultiDefaultValue
for k, (_, field_instance) in new_fields.items():
field_instance.default = field_instance.default.resolve_default()
# try to resolve the MultiDefaultValue objects
for name, (field_type, field_info) in new_fields.items():
multi_default = field_info.default
if isinstance(multi_default, MultiDefaultValue):
new_fields[name] = (field_type, field(default=multi_default.resolve()))

return [(name,) + field_info for name, field_info in new_fields.items()]


def create_config_class_from_args(config_class, args):
"""
Load a dataclass from the arguments.
Load a dataclass from the arguments. We don't include field names that were not set,
i.e. that were MultiDefaultValue.
"""
kwargs = {
f.name: getattr(args, f.name)
Expand Down Expand Up @@ -263,60 +263,46 @@ def fromdict(cls, data: Dict):
return cls.fromdict_legacy(data)


class MetaRegistrable(type):
"""
Meta class that creates a new dataclass containing fields all the config dataclasses
in this registrable.
"""

def __new__(cls, name, bases, attrs, registrable: Registrable = None):
def dataclass_from_registrable(registrable):
def decorator(cls):
module_name, class_name = registrable.rsplit(".", 1)
module = importlib.import_module(module_name)

registrable_class = getattr(module, class_name)

# make the union of all the fields across the registered configs
# Make the union of all the fields across the registered configs
to_tuples = dataclasses_union(*registrable_class.registered_configs())

# create new dataclass with the union of all the fields
new_cls = make_dataclass(name, to_tuples, bases=(Args,), init=False)
# Create new dataclass with the union of all the fields
new_cls = make_dataclass(cls.__name__, to_tuples, bases=(Args,), init=False)
new_cls.registrable_class = registrable_class

# set functions to be had in the new baby dataclass
for k, v in {
**attrs,
}.items():
setattr(new_cls, k, v)
# Copy attributes and methods from the original class
for k, v in cls.__dict__.items():
if not k.startswith("__"):
setattr(new_cls, k, v)

return new_cls

return decorator

@dataclass
class DataArgs(
metaclass=MetaRegistrable, registrable="mttl.datamodule.base.DataModule"
):

@dataclass_from_registrable("mttl.datamodule.base.DataModule")
class DataArgs(Args):
pass


@dataclass
class SelectorArgs(
metaclass=MetaRegistrable,
registrable="mttl.models.containers.selectors.base.Selector",
):
@dataclass_from_registrable("mttl.models.containers.selectors.base.Selector")
class SelectorArgs(Args):
pass


@dataclass
class ModifierArgs(
metaclass=MetaRegistrable, registrable="mttl.models.modifiers.base.Modifier"
):
@dataclass_from_registrable("mttl.models.modifiers.base.Modifier")
class ModifierArgs(Args):
pass


@dataclass
class TransformArgs(
metaclass=MetaRegistrable,
registrable="mttl.models.library.library_transforms.LibraryTransform",
):
@dataclass_from_registrable("mttl.models.library.library_transforms.LibraryTransform")
class TransformArgs(Args):
pass


Expand Down
Loading
Loading