From 08c07e5510cb785d47f579638dc86c7f9e145a49 Mon Sep 17 00:00:00 2001 From: breixo Date: Fri, 31 Jul 2020 12:42:08 +0200 Subject: [PATCH 1/3] Removed metrics as they are deprecated --- mcfly/find_architecture.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mcfly/find_architecture.py b/mcfly/find_architecture.py index 488cbde..6b75b1c 100644 --- a/mcfly/find_architecture.py +++ b/mcfly/find_architecture.py @@ -75,7 +75,7 @@ def train_models_on_samples(X_train, y_train, X_val, y_val, models, or `(inputs, targets, sample_weights)` - generator or `keras.utils.Sequence`. Should return a tuple of `(inputs, targets)` or `(inputs, targets, sample_weights)` - + The input dataset for validation of shape (num_samples_val, num_timesteps, num_channels) More details can be found in the documentation for the Keras @@ -185,7 +185,7 @@ def train_models_on_samples(X_train, y_train, X_val, y_val, models, if outputfile is not None: store_train_hist_as_json(params, model_types, history.history, - outputfile, model.metrics_names[0]) + outputfile) if model_path is not None: model.save(os.path.join(model_path, 'model_{}.h5'.format(i))) From 5734caedbc93463c5d8a157b7b7eaddc9b0fc816 Mon Sep 17 00:00:00 2001 From: breixo Date: Fri, 31 Jul 2020 12:52:35 +0200 Subject: [PATCH 2/3] Updated model selection --- tests/test_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_integration.py b/tests/test_integration.py index b822138..de0b085 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -19,7 +19,7 @@ def test_integration(self): number_of_classes=num_classes, number_of_models=2, metrics=[metric], - model_type='CNN') # Because CNNs are quick to train. + model_types=['CNN']) # Because CNNs are quick to train. histories, val_accuracies, val_losses = find_architecture.train_models_on_samples(X_train, y_train, X_val, y_val, models, nr_epochs=5, From a259f85d50f647fb748e20b54f6eaf010983fd9b Mon Sep 17 00:00:00 2001 From: breixo Date: Fri, 31 Jul 2020 12:56:30 +0200 Subject: [PATCH 3/3] Restricted the model type to avoid warnings --- tests/test_storage.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_storage.py b/tests/test_storage.py index f2c104c..e1084b9 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -53,7 +53,10 @@ def create_dummy_model(): num_time_steps = 100 num_channels = 2 num_samples_train = 5 - model, _parameters, _type = modelgen.generate_models((num_samples_train, num_time_steps, num_channels), 5, 1)[0] + model, _parameters, _type = modelgen.generate_models( + (num_samples_train, num_time_steps, num_channels), 5, 1, + ['CNN'] # Chosen one model type to avoid warnings + )[0] return model