diff --git a/mcfly/find_architecture.py b/mcfly/find_architecture.py index 488cbde..6b75b1c 100644 --- a/mcfly/find_architecture.py +++ b/mcfly/find_architecture.py @@ -75,7 +75,7 @@ def train_models_on_samples(X_train, y_train, X_val, y_val, models, or `(inputs, targets, sample_weights)` - generator or `keras.utils.Sequence`. Should return a tuple of `(inputs, targets)` or `(inputs, targets, sample_weights)` - + The input dataset for validation of shape (num_samples_val, num_timesteps, num_channels) More details can be found in the documentation for the Keras @@ -185,7 +185,7 @@ def train_models_on_samples(X_train, y_train, X_val, y_val, models, if outputfile is not None: store_train_hist_as_json(params, model_types, history.history, - outputfile, model.metrics_names[0]) + outputfile) if model_path is not None: model.save(os.path.join(model_path, 'model_{}.h5'.format(i))) diff --git a/tests/test_integration.py b/tests/test_integration.py index b822138..de0b085 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -19,7 +19,7 @@ def test_integration(self): number_of_classes=num_classes, number_of_models=2, metrics=[metric], - model_type='CNN') # Because CNNs are quick to train. + model_types=['CNN']) # Because CNNs are quick to train. histories, val_accuracies, val_losses = find_architecture.train_models_on_samples(X_train, y_train, X_val, y_val, models, nr_epochs=5, diff --git a/tests/test_storage.py b/tests/test_storage.py index f2c104c..e1084b9 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -53,7 +53,10 @@ def create_dummy_model(): num_time_steps = 100 num_channels = 2 num_samples_train = 5 - model, _parameters, _type = modelgen.generate_models((num_samples_train, num_time_steps, num_channels), 5, 1)[0] + model, _parameters, _type = modelgen.generate_models( + (num_samples_train, num_time_steps, num_channels), 5, 1, + ['CNN'] # Chosen one model type to avoid warnings + )[0] return model