From be73cde5bcb9ef2936cbdb018ab64e18979a98f3 Mon Sep 17 00:00:00 2001 From: "Colleen J. Gillon" Date: Thu, 1 Feb 2024 15:33:57 +0000 Subject: [PATCH] Added a warning if number of neurons will be changed in the initialization of a neuron layer. Adde an error in VectorCells init if class is initialized directly. Removed default value of None for `Other_Agent` in AgentVectorCells init, as None is not an allowed value. Raised an error in ObjectVectorCells if there are not objects in the environment. --- ratinabox/Neurons.py | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/ratinabox/Neurons.py b/ratinabox/Neurons.py index 45d37a1..9daeaf7 100644 --- a/ratinabox/Neurons.py +++ b/ratinabox/Neurons.py @@ -1209,6 +1209,10 @@ def __init__(self, Agent, params={}): Agent. The RatInABox Agent these cells belong to. params (dict, optional). Defaults to {}. """ + + if type(self) == VectorCells: + raise RuntimeError("Cannot instantiate VectorCells on their own. Must be instantiated through one of the subclasses, e.g., ObjectVectorCells, FieldOfView_BVCs, etc.") + assert ( self.Agent.Environment.dimensionality == "2D" ), "Vector cells only possible in 2D" @@ -1245,18 +1249,28 @@ def __init__(self, Agent, params={}): self.sigma_angles = None self.sigma_distances = None - # set the parameters of each cell. (self.tuning_distances, self.tuning_angles, self.sigma_distances, self.sigma_angles) = self.set_tuning_parameters(**self.params) - self.n = len(self.tuning_distances) #ensure n is correct + + # records whether n was passed as a parameter. + if not hasattr(self, "_warn_n_change"): + self._warn_n_change = ("n" in params.keys() and params["n"] is not None) + + # raises a warning if n was passed as a parameter, but will change. + if self._warn_n_change: + warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed, and setting number of {self.name} neurons to {len(self.tuning_distances)}, inferred from the cell arrangement parameter.") + + self.n = len(self.tuning_distances) # ensure n is correct + + self.firingrate = np.zeros(self.n) self.noise = np.zeros(self.n) self.cell_colors = None - + def set_tuning_parameters(self, **kwargs): """Get the tuning parameters for the vector cells. Args: @@ -1438,6 +1452,10 @@ def __init__(self, Agent, params={}): self.params = copy.deepcopy(__class__.default_params) self.params.update(params) + # records whether n was passed as a parameter. + if not hasattr(self, "_warn_n_change"): + self._warn_n_change = ("n" in params.keys() and params["n"] is not None) + super().__init__(Agent, self.params) assert ( @@ -1483,7 +1501,6 @@ def __init__(self, Agent, params={}): return - def get_state(self, evaluate_at="agent", **kwargs): """ Here we implement the same type if boundary vector cells as de Cothi et al. (2020), @@ -1789,9 +1806,15 @@ def __init__(self, Agent, params={}): self.Agent.Environment.dimensionality == "2D" ), "object vector cells only possible in 2D" + # records whether n was passed as a parameter. + if not hasattr(self, "_warn_n_change"): + self._warn_n_change = ("n" in params.keys() and params["n"] is not None) + super().__init__(Agent, self.params) self.object_locations = self.Agent.Environment.objects["objects"] + if len(self.object_locations) == 0: + raise RuntimeError(f"Cannot initialize {self.params['name']}, as there are no objects in the environment.") self.tuning_types = None @@ -2006,7 +2029,6 @@ def __init__(self,Agent,params={}): warnings.warn("For FieldOfViewOVCs you must specify the object type they are selective for with the 'object_tuning_type' parameter. This can be 'random' (each cell in the field of view chooses a random object type) or any integer (all cells have the same preference for this type). For now defaulting to params['object_tuning_type'] = 0.") self.params["object_tuning_type"] = 0 - self.params["reference_frame"] = "egocentric" assert self.params["cell_arrangement"] is not None, "cell_arrangement must be set for FOV Neurons" @@ -2034,7 +2056,7 @@ class AgentVectorCells(VectorCells): def __init__(self, Agent, - Other_Agent = None, #this must be another riab Agent object + Other_Agent, #this must be another riab Agent object params={}): self.Agent = Agent @@ -2269,6 +2291,8 @@ def __init__(self, Agent, params={}): [self.params["angular_spread_degrees"] * np.pi / 180] * self.n ) if self.Agent.Environment.dimensionality == "1D": + if "n" in params.keys() and params["n"] != 2: + warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed for {self.params['name']}. Only 2 head direction cells are needed for a 1D environment.") self.n = 2 # one left, one right self.params["n"] = self.n super().__init__(Agent, self.params) @@ -2476,7 +2500,11 @@ def __init__(self, Agent, params={}): self.params.update(params) super().__init__(Agent, self.params) + + if "n" in params.keys() and params["n"] != 1: + warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed for {self.name}. Only 1 speed cell is needed.") self.n = 1 + self.one_sigma_speed = self.Agent.speed_mean + self.Agent.speed_std if ratinabox.verbose is True: