Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid infinite recursion in plot_rate_map() #95

Merged
merged 1 commit into from
Nov 30, 2023

Conversation

colleenjg
Copy link
Contributor

Enabled recurrent inputs to be identified when adding inputs to FeedForward Neurons. These inputs will be ignored when a groundtruth rate map is plotted, to avoid infinite recursion.

@TomGeorge1234
Copy link
Collaborator

TomGeorge1234 commented Nov 29, 2023

Oh...I love this!!!!

Do you think it would be cleaner if, instead get_state() instead received a max_recursion_depth argument defaulting to 1. Then the internal call does something like:

if recursion_depth > 0 and inputlayer['recurrent']:
    continue
if evaluate_at == "last": 
    I = inputlayer["layer"].firingrate
else: #kick the can down the road, except for expired recurrent loops
    w = inputlayer["w"]
    I = inputlayer["layer"].get_state(evaluate_at, recursion_depth=recursion_depth-1, **kwargs)
    V += np.matmul(w, I)

Then you wouldn't even need a new plot_rate_map() function. And users could control the recursion depth if they wanted to by FFL.plot_rate_map(recursion_depth=42).

This has the benefits that recursion can never be infinite. It is slightly different from yours though because here the recursive loop is still used exactly once for rate maps. We could either avoid this by (i) setting recursion_depth=0 as default (then we'd need to put the evaluate_at="last" outside the if test) (ii) write a (now even simpler) plot_rate_map() wrapper which is just super.plot_rate_map(recursion_depth=0) or (iii) write an new update() wrapper which forces recursion >0 for online update. This last on feels a bit weird imo but would work.

What do you think - you've probably thought about this a lot more than I have.

@colleenjg
Copy link
Contributor Author

That's even better! Because my version really prevents any contribution from recursive inputs from shaping the rate maps.

I think what you're proposing looks like a clean solution! The only oddity I noticed was that recursion_depth doesn't quite describe what the variable does, in my view, since this variable is decremented even if you don't have exact recursion (i.e., layer calling itself). What I mean is that if you have a loop with three nodes, each node in the loop will decrement the depth by 1, instead of one pass through the full loop decreasing it by 1.

So perhaps we can call it max_depth_if_recursion. Or is that an annoyingly long name? And I would suggest to only this decrement this variable it if the inputlayer is recurrent.

pass_max_depth_if_recursion = max_depth_if_recursion
if inputlayer['recurrent']:
    if max_depth_if_recursion == 0
        continue
    else:
        pass_max_depth_if_recursion = max_depth_if_recursion - 1

if evaluate_at == "last": 
    I = inputlayer["layer"].firingrate
else: #kick the can down the road, except for expired recurrent loops
    w = inputlayer["w"]
    I = inputlayer["layer"].get_state(evaluate_at, max_depth_if_recursion=pass_max_depth_if_recursion, **kwargs)
    V += np.matmul(w, I)

Do you think this would work?

@colleenjg
Copy link
Contributor Author

I'm realizing I didn't follow the end of your comment. What if we put the default as max_depth_if_recursion=None, and so it's ignored, unless it's set, and we add plot_rate_maps(max_depth_if_recursion=None), as you suggest? I'll push a new version to clarify.

@TomGeorge1234
Copy link
Collaborator

TomGeorge1234 commented Nov 29, 2023

My mistake, definitely should only be decremented once per loop, good spot.

You made me realise there's an important distinction between recursion (get_state() calling another get_state(), which strictly applies to all inputs which are FeedForwardLayers) and recurrence (inputs which eventually circle back on themselves). The correct name is something like max_recursion_depth_for_recurrent_inputs but that's ridiculous. What about max_recurrence (which clarifies it's the variable which applies to recurrent loops in the graph structure).

`max_recurrence`: 1 # The maximum number of time get_state() recursively calls recurrent inputs (prevents recursion error when plotting rate maps).  

Perhaps more readable is:

# Skip this input if you're past the recurrence limit
if inputlayer['recurrent'] and max_depth_if_recursion>0:
    continue

# Get layer input, either from its current firingrate or from recursively calling Input.get_state(). 
if evaluate_at == "last": 
    I = inputlayer["layer"].firingrate
else: # kick the can down the road
    w = inputlayer["w"]
    I = inputlayer["layer"].get_state(evaluate_at, max_recurrence = max_recurrence-inputlayer['recurrent'], **kwargs) # decreases the recursion depth iff the layer input is flagged as recursive. 
    V += np.matmul(w, I)

Thoughts?

We should also ad a comment into add_input() clarifying that only one node in the recursive loop must be flagged as recursive.

@colleenjg
Copy link
Contributor Author

colleenjg commented Nov 29, 2023

Yeah, that makes sense!

if inputlayer['recurrent'] and max_depth_if_recursion > 0:
    continue

should be

if inputlayer['recurrent'] and max_depth_if_recursion <= 0:
    continue

Are you ok with the default being None, in which case this is ignored? Kind of forcing the user to reflect on what depth they want or to get a recursion error?

@TomGeorge1234
Copy link
Collaborator

Sure (we can always change the default down the line if we change our mind)...but how will it be ignored? None-1 will throw an error.

@colleenjg
Copy link
Contributor Author

Yeah, my version has a bit more lines, where None is checked

@colleenjg
Copy link
Contributor Author

colleenjg commented Nov 29, 2023

Ok, I just pushed a new version that seems to be working, on my end!

Edit: Sorry for the numerous force pushes, I kept finding typos. This should work now.

@colleenjg colleenjg force-pushed the cjg-dev branch 3 times, most recently from 652f7c8 to 6f73978 Compare November 29, 2023 17:12
…en evaluating get_state(), e.g., when plotting rate maps.
Copy link
Collaborator

@TomGeorge1234 TomGeorge1234 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perfect, great work!

@TomGeorge1234 TomGeorge1234 merged commit 16375e1 into RatInABox-Lab:dev Nov 30, 2023
3 checks passed
@TomGeorge1234
Copy link
Collaborator

v1.11.2 fixes this

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants