Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes and new warning for Neurons classes. #101

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,10 @@ def __init__(self, Agent, params={}):
Agent. The RatInABox Agent these cells belong to.
params (dict, optional). Defaults to {}.
"""

if type(self) == VectorCells:
raise RuntimeError("Cannot instantiate VectorCells on their own. Must be instantiated through one of the subclasses, e.g., ObjectVectorCells, FieldOfView_BVCs, etc.")

assert (
self.Agent.Environment.dimensionality == "2D"
), "Vector cells only possible in 2D"
Expand Down Expand Up @@ -1245,18 +1249,28 @@ def __init__(self, Agent, params={}):
self.sigma_angles = None
self.sigma_distances = None


# set the parameters of each cell.
(self.tuning_distances,
self.tuning_angles,
self.sigma_distances,
self.sigma_angles) = self.set_tuning_parameters(**self.params)
self.n = len(self.tuning_distances) #ensure n is correct

# records whether n was passed as a parameter.
if not hasattr(self, "_warn_n_change"):
self._warn_n_change = ("n" in params.keys() and params["n"] is not None)

# raises a warning if n was passed as a parameter, but will change.
if self._warn_n_change:
warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed, and setting number of {self.name} neurons to {len(self.tuning_distances)}, inferred from the cell arrangement parameter.")

self.n = len(self.tuning_distances) # ensure n is correct


self.firingrate = np.zeros(self.n)
self.noise = np.zeros(self.n)
self.cell_colors = None


def set_tuning_parameters(self, **kwargs):
"""Get the tuning parameters for the vector cells.
Args:
Expand Down Expand Up @@ -1438,6 +1452,10 @@ def __init__(self, Agent, params={}):
self.params = copy.deepcopy(__class__.default_params)
self.params.update(params)

# records whether n was passed as a parameter.
if not hasattr(self, "_warn_n_change"):
self._warn_n_change = ("n" in params.keys() and params["n"] is not None)

super().__init__(Agent, self.params)

assert (
Expand Down Expand Up @@ -1483,7 +1501,6 @@ def __init__(self, Agent, params={}):
return



def get_state(self, evaluate_at="agent", **kwargs):
"""
Here we implement the same type if boundary vector cells as de Cothi et al. (2020),
Expand Down Expand Up @@ -1789,9 +1806,15 @@ def __init__(self, Agent, params={}):
self.Agent.Environment.dimensionality == "2D"
), "object vector cells only possible in 2D"

# records whether n was passed as a parameter.
if not hasattr(self, "_warn_n_change"):
self._warn_n_change = ("n" in params.keys() and params["n"] is not None)

super().__init__(Agent, self.params)

self.object_locations = self.Agent.Environment.objects["objects"]
if len(self.object_locations) == 0:
raise RuntimeError(f"Cannot initialize {self.params['name']}, as there are no objects in the environment.")

self.tuning_types = None

Expand Down Expand Up @@ -2006,7 +2029,6 @@ def __init__(self,Agent,params={}):
warnings.warn("For FieldOfViewOVCs you must specify the object type they are selective for with the 'object_tuning_type' parameter. This can be 'random' (each cell in the field of view chooses a random object type) or any integer (all cells have the same preference for this type). For now defaulting to params['object_tuning_type'] = 0.")
self.params["object_tuning_type"] = 0


self.params["reference_frame"] = "egocentric"
assert self.params["cell_arrangement"] is not None, "cell_arrangement must be set for FOV Neurons"

Expand Down Expand Up @@ -2034,7 +2056,7 @@ class AgentVectorCells(VectorCells):

def __init__(self,
Agent,
Other_Agent = None, #this must be another riab Agent object
Other_Agent, #this must be another riab Agent object
params={}):

self.Agent = Agent
Expand Down Expand Up @@ -2269,6 +2291,8 @@ def __init__(self, Agent, params={}):
[self.params["angular_spread_degrees"] * np.pi / 180] * self.n
)
if self.Agent.Environment.dimensionality == "1D":
if "n" in params.keys() and params["n"] != 2:
warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed for {self.params['name']}. Only 2 head direction cells are needed for a 1D environment.")
self.n = 2 # one left, one right
self.params["n"] = self.n
super().__init__(Agent, self.params)
Expand Down Expand Up @@ -2476,7 +2500,11 @@ def __init__(self, Agent, params={}):
self.params.update(params)

super().__init__(Agent, self.params)

if "n" in params.keys() and params["n"] != 1:
warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed for {self.name}. Only 1 speed cell is needed.")
self.n = 1

self.one_sigma_speed = self.Agent.speed_mean + self.Agent.speed_std

if ratinabox.verbose is True:
Expand Down
Loading