From ba35c38a9384285ce33407610a0d738ed182679c Mon Sep 17 00:00:00 2001 From: Quinten Date: Sun, 14 Apr 2024 22:59:04 -0700 Subject: [PATCH] move hyperparameters to config --- src/revnets/models/config.py | 4 +++- .../reconstructions/queries/iterative/difficult_inputs.py | 8 +++++--- tests/conftest.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/revnets/models/config.py b/src/revnets/models/config.py index a0c6453..9d0bec2 100644 --- a/src/revnets/models/config.py +++ b/src/revnets/models/config.py @@ -47,6 +47,9 @@ class Config(SerializationMixin): target_network_training: HyperParameters = HyperParameters( epochs=100, learning_rate=1.0e-2, batch_size=32 ) + difficult_inputs_training: HyperParameters = HyperParameters( + epochs=1000, learning_rate=1.0e-3 + ) evaluation: Evaluation = field(default_factory=Evaluation) evaluation_batch_size: int = 1000 @@ -68,7 +71,6 @@ class Config(SerializationMixin): validation_ratio: float = 0.1 console_metrics_refresh_interval: float = 0.5 - max_difficult_inputs_epochs: int = 1000 limit_batches: int | None = None diff --git a/src/revnets/reconstructions/queries/iterative/difficult_inputs.py b/src/revnets/reconstructions/queries/iterative/difficult_inputs.py index 8582619..82a1a22 100644 --- a/src/revnets/reconstructions/queries/iterative/difficult_inputs.py +++ b/src/revnets/reconstructions/queries/iterative/difficult_inputs.py @@ -20,11 +20,13 @@ def __init__( self, shape: tuple[int, ...], reconstructions: list[torch.nn.Sequential], - learning_rate: float = 0.001, + learning_rate: float | None = None, verbose: bool = True, ) -> None: super().__init__() self.shape = shape + if learning_rate is None: + learning_rate = context.config.difficult_inputs_training.learning_rate self.learning_rate = learning_rate self.inputs_embedding = self.create_input_embeddings(shape) self.reconstructions = torch.nn.ModuleList(reconstructions) @@ -101,8 +103,8 @@ def create_difficult_samples(self) -> torch.Tensor: @classmethod def fit_inputs_network(cls, network: InputNetwork) -> None: - max_epochs = context.config.max_difficult_inputs_epochs - trainer = Trainer(max_epochs=max_epochs, log_every_n_steps=1) + epochs = context.config.difficult_inputs_training.epochs + trainer = Trainer(max_epochs=epochs, log_every_n_steps=1) dataset = EmptyDataset() dataloader = DataLoader(dataset) trainer.fit(network, dataloader) diff --git a/tests/conftest.py b/tests/conftest.py index 174b9c0..c164f87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,7 +32,7 @@ def test_context(context: Context, mocked_assets_path: None) -> Iterator[Context context.loaders.config.value = Config( target_network_training=hyperparameters, reconstruction_training=hyperparameters, - max_difficult_inputs_epochs=1, + difficult_inputs_training=hyperparameters, evaluation=evaluation, limit_batches=5, )