Skip to content

Commit

Permalink
Fixed mistake I introduced in PR RatInABox-Lab#101, raised in issue R…
Browse files Browse the repository at this point in the history
…atInABox-Lab#102.

Reintroduced an equality check for raising a warning if value of self.n was passed, but is reset internally in a VectorCells class.
Equality is checked if `cell_arrangement` does not end in `"manifold"` (i.e., `"diverging_manifold"` or `"uniform_manifold"`.
If `cell_arrangement` is `"random"` or a callable function, a warning is only raised if the value of self.n would actually change.
  • Loading branch information
colleenjg committed Feb 2, 2024
1 parent 259d072 commit d5cf9eb
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,12 +1256,14 @@ def __init__(self, Agent, params={}):
self.sigma_angles) = self.set_tuning_parameters(**self.params)

# records whether n was passed as a parameter.
if not hasattr(self, "_warn_n_change"):
self._warn_n_change = ("n" in params.keys() and params["n"] is not None)
if not hasattr(self, "_warn_if_n_changes"):
self._warn_if_n_changes = ("n" in params.keys() and params["n"] is not None)

# raises a warning if n was passed as a parameter, but will change.
if self._warn_n_change:
warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed, and setting number of {self.name} neurons to {len(self.tuning_distances)}, inferred from the cell arrangement parameter.")
if self._warn_if_n_changes:
dont_check_equality = (isinstance(self.params["cell_arrangement"], str) and self.params["cell_arrangement"].endswith("manifold"))
if dont_check_equality or self.n != len(self.tuning_distances):
warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed, and setting number of {self.name} neurons to {len(self.tuning_distances)}, inferred from the cell arrangement parameter.")

self.n = len(self.tuning_distances) # ensure n is correct

Expand Down Expand Up @@ -1453,8 +1455,8 @@ def __init__(self, Agent, params={}):
self.params.update(params)

# records whether n was passed as a parameter.
if not hasattr(self, "_warn_n_change"):
self._warn_n_change = ("n" in params.keys() and params["n"] is not None)
if not hasattr(self, "_warn_if_n_changes"):
self._warn_if_n_changes = ("n" in params.keys() and params["n"] is not None)

super().__init__(Agent, self.params)

Expand Down Expand Up @@ -1807,8 +1809,8 @@ def __init__(self, Agent, params={}):
), "object vector cells only possible in 2D"

# records whether n was passed as a parameter.
if not hasattr(self, "_warn_n_change"):
self._warn_n_change = ("n" in params.keys() and params["n"] is not None)
if not hasattr(self, "_warn_if_n_changes"):
self._warn_if_n_changes = ("n" in params.keys() and params["n"] is not None)

super().__init__(Agent, self.params)

Expand Down

0 comments on commit d5cf9eb

Please sign in to comment.