From d5cf9ebcf61eb4cd4774ef6fae62fbfc9aaa1c66 Mon Sep 17 00:00:00 2001 From: "Colleen J. Gillon" Date: Fri, 2 Feb 2024 11:03:14 +0000 Subject: [PATCH] Fixed mistake I introduced in PR #101, raised in issue #102. Reintroduced an equality check for raising a warning if value of self.n was passed, but is reset internally in a VectorCells class. Equality is checked if `cell_arrangement` does not end in `"manifold"` (i.e., `"diverging_manifold"` or `"uniform_manifold"`. If `cell_arrangement` is `"random"` or a callable function, a warning is only raised if the value of self.n would actually change. --- ratinabox/Neurons.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ratinabox/Neurons.py b/ratinabox/Neurons.py index 9daeaf7..9d66dfe 100644 --- a/ratinabox/Neurons.py +++ b/ratinabox/Neurons.py @@ -1256,12 +1256,14 @@ def __init__(self, Agent, params={}): self.sigma_angles) = self.set_tuning_parameters(**self.params) # records whether n was passed as a parameter. - if not hasattr(self, "_warn_n_change"): - self._warn_n_change = ("n" in params.keys() and params["n"] is not None) + if not hasattr(self, "_warn_if_n_changes"): + self._warn_if_n_changes = ("n" in params.keys() and params["n"] is not None) # raises a warning if n was passed as a parameter, but will change. - if self._warn_n_change: - warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed, and setting number of {self.name} neurons to {len(self.tuning_distances)}, inferred from the cell arrangement parameter.") + if self._warn_if_n_changes: + dont_check_equality = (isinstance(self.params["cell_arrangement"], str) and self.params["cell_arrangement"].endswith("manifold")) + if dont_check_equality or self.n != len(self.tuning_distances): + warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed, and setting number of {self.name} neurons to {len(self.tuning_distances)}, inferred from the cell arrangement parameter.") self.n = len(self.tuning_distances) # ensure n is correct @@ -1453,8 +1455,8 @@ def __init__(self, Agent, params={}): self.params.update(params) # records whether n was passed as a parameter. - if not hasattr(self, "_warn_n_change"): - self._warn_n_change = ("n" in params.keys() and params["n"] is not None) + if not hasattr(self, "_warn_if_n_changes"): + self._warn_if_n_changes = ("n" in params.keys() and params["n"] is not None) super().__init__(Agent, self.params) @@ -1807,8 +1809,8 @@ def __init__(self, Agent, params={}): ), "object vector cells only possible in 2D" # records whether n was passed as a parameter. - if not hasattr(self, "_warn_n_change"): - self._warn_n_change = ("n" in params.keys() and params["n"] is not None) + if not hasattr(self, "_warn_if_n_changes"): + self._warn_if_n_changes = ("n" in params.keys() and params["n"] is not None) super().__init__(Agent, self.params)