Skip to content

Commit

Permalink
python<=3.8 bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
TomGeorge1234 committed Mar 13, 2024
1 parent 383a821 commit 1390c86
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 30 deletions.
28 changes: 16 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# RatInABox ![Tests](https://github.com/RatInABox-Lab/RatInABox/actions/workflows/test.yml/badge.svg) [![PyPI version](https://badge.fury.io/py/ratinabox.svg)](https://badge.fury.io/py/ratinabox) [![Downloads](https://static.pepy.tech/badge/ratinabox)](https://pepy.tech/project/ratinabox)<img align="right" src=".images/readme/logo.png" width=150>

`RatInABox` (see [paper](https://www.biorxiv.org/content/10.1101/2022.08.10.503541v5)) is a toolkit for generating synthetic behaviour and neural data for spatially and/or velocity selective cell types in complex continuous environments.
`RatInABox` (see [paper](https://elifesciences.org/articles/85274)) is a toolkit for generating synthetic behaviour and neural data for spatially and/or velocity selective cell types in complex continuous environments.

[**Install**](#installing-and-importing) | [**Demos**](#get-started) | [**Features**](#feature-run-down) | [**Contributions and Questions**](#contribute) | [**Cite**](#cite)

Expand Down Expand Up @@ -443,26 +443,30 @@ Questions? Just ask! Ideally via opening an issue so others can see the answer t
Thanks to all contributors so far:
![GitHub Contributors Image](https://contrib.rocks/image?repo=RatInABox-Lab/RatInABox)

## Cite [![](http://img.shields.io/badge/bioRxiv-10.1101/2022.08.10.503541-B31B1B.svg)](https://doi.org/10.1101/2022.08.10.503541)
## Cite

If you use `RatInABox` in your research or educational material, please cite the work as follows:

Bibtex:
```
@article{ratinabox2022,
doi = {10.1101/2022.08.10.503541},
url = {https://doi.org/10.1101%2F2022.08.10.503541},
year = 2022,
month = {aug},
publisher = {Cold Spring Harbor Laboratory},
author = {Tom M George and William de Cothi and Claudia Clopath and Kimberly Stachenfeld and Caswell Barry},
title = {{RatInABox}: A toolkit for modelling locomotion and neuronal activity in continuous environments}
@article{George2024,
title = {RatInABox, a toolkit for modelling locomotion and neuronal activity in continuous environments},
volume = {13},
ISSN = {2050-084X},
url = {http://dx.doi.org/10.7554/eLife.85274},
DOI = {10.7554/elife.85274},
journal = {eLife},
publisher = {eLife Sciences Publications, Ltd},
author = {George, Tom M and Rastogi, Mehul and de Cothi, William and Clopath, Claudia and Stachenfeld, Kimberly and Barry, Caswell},
year = {2024},
month = feb
}
```

Formatted:
```
Tom M George, William de Cothi, Claudia Clopath, Kimberly Stachenfeld, Caswell Barry. "RatInABox: A toolkit for modelling locomotion and neuronal activity in continuous environments" (2022).
Tom M George, Mehul Rastogi, William de Cothi, Claudia Clopath, Kimberly Stachenfeld, Caswell Barry. "RatInABox, a toolkit for modelling locomotion and neuronal activity in continuous environments" (2024), eLife, https://doi.org/10.7554/eLife.85274 .
```
The research paper corresponding to the above citation can be found [here](https://www.biorxiv.org/content/10.1101/2022.08.10.503541v4).
The research paper corresponding to the above citation can be found [here](https://elifesciences.org/articles/85274).


14 changes: 7 additions & 7 deletions ratinabox/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


import warnings
from typing import Union
from typing import Union, List

from ratinabox import utils
from ratinabox.Agent import Agent
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, params={}):
utils.update_class_params(self, self.params, get_all_defaults=True)
utils.check_params(self, params.keys())

self.Agents : list[Agent] = [] # each new Agent will append itself to this list
self.Agents : List[Agent] = [] # each new Agent will append itself to this list
self.agents_dict = {} # this is a dictionary which allows you to lookup a agent by name

if self.dimensionality == "1D":
Expand Down Expand Up @@ -206,17 +206,17 @@ def get_all_default_params(cls, verbose=False):
return all_default_params


def agent_lookup(self, agent_names:Union[str, list[str]] = None) -> list[Agent]:
def agent_lookup(self, agent_names:Union[str, List[str]] = None) -> List[Agent]:
'''
This function will lookup a agent by name and return it. This assumes that the agent has been
added to the Environment.agents list and that each agent object has a unique name associated with it.
Args:
agent_names (str, list[str]): the name of the agent you want to lookup.
agent_names (str, List[str]): the name of the agent you want to lookup.
Returns:
agents (list[Agent]): a list of agents that match the agent_names. If agent_names is a string, then a list of length 1 is returned. If agent_names is None, then None is returned
agents (List[Agent]): a list of agents that match the agent_names. If agent_names is a string, then a list of length 1 is returned. If agent_names is None, then None is returned
'''

Expand All @@ -226,7 +226,7 @@ def agent_lookup(self, agent_names:Union[str, list[str]] = None) -> list[Agent]
if isinstance(agent_names, str):
agent_names = [agent_names]

agents: list[Agent] = []
agents: List[Agent] = []

for agent_name in agent_names:
agent = self._agent_lookup(agent_name)
Expand Down Expand Up @@ -846,7 +846,7 @@ def apply_boundary_conditions(self, pos):
returns new_pos
TODO update this so if pos is in one of the holes the Agent is returned to the ~nearest legal location inside the Environment
"""
if self.check_if_position_is_in_environment(pos) is True: return
if self.check_if_position_is_in_environment(pos) is True: return pos

if self.dimensionality == "1D":
if self.boundary_conditions == "periodic":
Expand Down
19 changes: 14 additions & 5 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,9 @@ def plot_rate_map(
interpolation="bicubic", # smooths rate maps but this does slow down the plotting a bit
)
elif method == "history":
bin_size = kwargs.get("bin_size", 0.05)
default_2D_bin_size = 0.05
bin_size = kwargs.get("bin_size", default_2D_bin_size)
print(f"Using bin size of {bin_size} for rate map calculation")
rate_timeseries_ = rate_timeseries[chosen_neurons[i], :]
rate_map, zero_bins = utils.bin_data_for_histogramming(
data=pos,
Expand Down Expand Up @@ -537,25 +539,32 @@ def plot_rate_map(

# PLOT 1D
elif self.Agent.Environment.dimensionality == "1D":
zero_bins = None
if method == "groundtruth":
rate_maps = rate_maps[chosen_neurons, :]
x = self.Agent.Environment.flattened_discrete_coords[:, 0]
if method == "history":
ex = self.Agent.Environment.extent
default_1D_bin_size = 0.01
bin_size = kwargs.get("bin_size", default_1D_bin_size)
pos_ = pos[:, 0]
rate_maps = []
for neuron_id in chosen_neurons:
rate_map, x = utils.bin_data_for_histogramming(
(rate_map, x, zero_bins) = utils.bin_data_for_histogramming(
data=pos_,
extent=ex,
dx=0.01,
dx=bin_size,
weights=rate_timeseries[neuron_id, :],
norm_by_bincount=True,
return_zero_bins=True,
)
x, rate_map = utils.interpolate_and_smooth(x, rate_map, sigma=0.03)
resolution_increase = 10
x, rate_map = utils.interpolate_and_smooth(x, rate_map, sigma=0.01, resolution_increase=resolution_increase)
rate_maps.append(rate_map)
zero_bins = np.repeat(zero_bins, resolution_increase)
rate_maps = np.array(rate_maps)


if fig is None and ax is None:
fig, ax = plt.subplots(
figsize=(
Expand All @@ -569,7 +578,7 @@ def plot_rate_map(

if method != "neither":
fig, ax = utils.mountain_plot(
X=x, NbyX=rate_maps, color=self.color, fig=fig, ax=ax, **kwargs
X=x, NbyX=rate_maps, color=self.color, nan_bins=zero_bins, fig=fig, ax=ax, **kwargs
)

if spikes is True:
Expand Down
22 changes: 16 additions & 6 deletions ratinabox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,20 +368,22 @@ def ornstein_uhlenbeck(dt, x, drift=0.0, noise_scale=0.2, coherence_time=5.0):
return dx


def interpolate_and_smooth(x, y, sigma=None):
def interpolate_and_smooth(x, y, sigma=None, resolution_increase=10):
"""Interpolates with cublic spline x and y to 10x resolution then smooths these with a gaussian kernel of width sigma.
Currently this only works for 1-dimensional x.
Args:
x
y
sigma
resolution_increase
Returns (x_new,y_new)
"""
from scipy.ndimage.filters import gaussian_filter1d
from scipy.interpolate import interp1d

y_cubic = interp1d(x, y, kind="cubic")
x_new = np.arange(x[0], x[-1], (x[1] - x[0]) / 10)
# x_new = np.arange(x[0], x[-1], (x[1] - x[0]) / resolution_increase)
x_new = np.linspace(x[0], x[-1], len(x) * resolution_increase)
y_interpolated = y_cubic(x_new)
if sigma is not None:
y_smoothed = gaussian_filter1d(
Expand Down Expand Up @@ -541,16 +543,20 @@ def bin_data_for_histogramming(data, extent, dx, weights=None, norm_by_bincount=
Returns:
(heatmap,bin_centres): if 1D
(heatmap): if 2D
(heatmap): if 2D --> you should be able ot infer the bin centres from the extent and dx you passed
in either case if return_zero_bins is True, the zero_bins array is also returned as the last element of the tuple
"""
if len(extent) == 2: # dimensionality = "1D"
bins = np.arange(extent[0], extent[1] + dx, dx)
heatmap, xedges = np.histogram(data, bins=bins, weights=weights)
if norm_by_bincount:
bincount = np.histogram(data, bins=bins)[0]
bincount[bincount == 0] = 1
zero_bins = (bincount == 0)
bincount[zero_bins] = 1
heatmap = heatmap / bincount
centres = (xedges[1:] + xedges[:-1]) / 2
if return_zero_bins:
return (heatmap, centres, zero_bins)
return (heatmap, centres)

elif len(extent) == 4: # dimensionality = "2D"
Expand Down Expand Up @@ -578,6 +584,7 @@ def mountain_plot(
xlabel="",
ylabel="",
xlim=None,
nan_bins=None,
fig=None,
ax=None,
norm_by="max",
Expand All @@ -599,6 +606,7 @@ def mountain_plot(
xlabel (str, optional): x axis label. Defaults to "".
ylabel (str, optional): y axis label. Defaults to "".
xlim (_type_, optional): fix xlim to this is desired. Defaults to None.
nan_bins (array, optional): Optionally pass a boolean array of the same shape as X which is True where you want to plot a gap in the mountain plot. Defaults to None (ie skipped).
fig (_type_, optional): fig to plot over if desired. Defaults to None.
ax (_type_, optional): ax to plot on if desider. Defaults to None.
norm_by: what to normalise each line of the mountainplot by.
Expand Down Expand Up @@ -630,11 +638,13 @@ def mountain_plot(
)

zorder = 1
X_ = X.copy()
if nan_bins is not None: X_[nan_bins] = np.nan
for i in range(len(NbyX)):
ax.plot(X, NbyX[i] + i + 1, c=c, zorder=zorder, lw=linewidth)
ax.plot(X_, NbyX[i] + i + 1, c=c, zorder=zorder, lw=linewidth)
zorder -= 0.01
ax.fill_between(
X, NbyX[i] + i + 1, i + 1, color=fc, zorder=zorder, alpha=0.8, linewidth=0
X_, NbyX[i] + i + 1, i + 1, color=fc, zorder=zorder, alpha=0.8, linewidth=0
)
zorder -= 0.01
ax.spines["left"].set_bounds(1, len(NbyX))
Expand Down

0 comments on commit 1390c86

Please sign in to comment.