Fixed the problem of the reset function of Memory corresponding to actor_critic_recurrent #35

Open
wants to merge 1 commit into base: master

thkkk commented Jul 17, 2024

Original code in class Memory(torch.nn.Module), in rsl_rl/modules/actor_critic_recurrent.py:

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        for hidden_state in self.hidden_states:
            hidden_state[..., dones, :] = 0.0

When I train a PPO policy with num_envs=1 using ActorCriticRecurrent, training crashes with the following error:

../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [127,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Traceback (most recent call last):
  File "rsl_rl/rsl_rl/runners/on_policy_runner.py", line 124, in learn
    self.alg.process_env_step(rewards, dones, infos)
  File "rsl_rl/rsl_rl/algorithms/ppo.py", line 95, in process_env_step
    self.actor_critic.reset(dones)
  File "rsl_rl/rsl_rl/modules/actor_critic_recurrent.py", line 54, in reset
    self.memory_a.reset(dones)
  File "rsl_rl/rsl_rl/modules/actor_critic_recurrent.py", line 102, in reset
    hidden_state[..., dones, :] = 0.0  # dones_envs_id
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

I added a print and an assertion to the code and found that dones.max() >= hidden_state.size(-2):

    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
        # dones: (num_envs,), hidden_states: (num_layers, num_envs, hidden_size)
        print(f"dones: {dones},  hidden: {self.hidden_states[0].shape}")
        for hidden_state in self.hidden_states:
            assert dones.max() < hidden_state.size(-2), f"dones {dones} index out of range {hidden_state.shape}"
            hidden_state[..., dones, :] = 0.0

The logs (num_envs=1) are below:

dones: tensor([0], device='cuda:0'),  hidden: torch.Size([2, 1, 256])
dones: tensor([0], device='cuda:0'),  hidden: torch.Size([2, 1, 256])
dones: tensor([0], device='cuda:0'),  hidden: torch.Size([2, 1, 256])
dones: tensor([0], device='cuda:0'),  hidden: torch.Size([2, 1, 256])
dones: tensor([1], device='cuda:0'),  hidden: torch.Size([2, 1, 256])

rsl_rl/rsl_rl/modules/actor_critic_recurrent.py", line 101, in reset
    assert dones.max() < hidden_state.size(-2), f"dones {dones} index out of range {hidden_state.shape}"
AssertionError: dones tensor([1], device='cuda:0') index out of range torch.Size([2, 1, 256])
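
The failure can be reproduced without CUDA, where the error message is easier to read. The snippet below is a minimal sketch that reuses the shapes from the logs above; the variable names are illustrative and it does not call into rsl_rl.

```python
import torch

# Shapes taken from the logs above: (num_layers, num_envs, hidden_size) with num_envs=1.
hidden_state = torch.ones(2, 1, 256)
# Integer-valued done flag, as printed in the failing step.
dones = torch.tensor([1])

try:
    # An integer tensor is interpreted as a list of positions along the env axis,
    # so this asks for env 1, which does not exist when num_envs=1.
    hidden_state[..., dones, :] = 0.0
    out_of_bounds = False
except (IndexError, RuntimeError) as err:
    out_of_bounds = True
    print(f"indexing failed: {err}")
```

On CPU the out-of-range index is reported directly at the indexing line instead of as an asynchronous device-side assert.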

The logs with num_envs=4 are below. This case does not raise an error, but the wrong entries of hidden_state are indexed:

dones: tensor([0, 0, 0, 0], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
dones: tensor([0, 0, 0, 0], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
dones: tensor([0, 0, 0, 0], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
dones: tensor([0, 0, 0, 0], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
dones: tensor([0, 0, 0, 1], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
dones: tensor([0, 0, 0, 1], device='cuda:0'),  hidden: torch.Size([2, 4, 256])
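
To make the num_envs=4 case concrete, here is a small sketch comparing what happens when the logged dones values are used as integer indices versus as a boolean mask. The hidden size of 3 is an arbitrary choice to keep the tensors small; everything else follows the shapes in the logs.

```python
import torch

# (num_layers, num_envs, hidden_size); hidden_size=3 is only for readability.
hidden_state = torch.ones(2, 4, 3)
dones = torch.tensor([0, 0, 0, 1])  # only env 3 has ended

# Integer indexing: the values 0 and 1 are used as env ids.
as_index = hidden_state.clone()
as_index[..., dones, :] = 0.0

# Boolean mask: True marks the envs to reset.
as_mask = hidden_state.clone()
as_mask[..., dones.bool(), :] = 0.0

# Which envs were zeroed in each case?
reset_by_index = (as_index == 0).all(dim=-1).all(dim=0).nonzero().flatten().tolist()
reset_by_mask = (as_mask == 0).all(dim=-1).all(dim=0).nonzero().flatten().tolist()
print(reset_by_index)  # [0, 1]
print(reset_by_mask)   # [3]
```

With integer dones, envs 0 and 1 lose their memory while env 3, the one that actually ended, keeps it.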

The elements of dones indicate whether each environment has ended, but the indexing above treats them as environment ids. What needs to be reset is the memory of the environments that have ended.
Therefore, the correct code looks up the ids of the envs whose dones are True:

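    def reset(self, dones=None):
        # When the RNN is an LSTM, self.hidden_states is a list with hidden_state and cell_state
        # dones: (num_envs,), hidden_states: (num_layers, num_envs, hidden_size)
        # `if dones` is ambiguous for a tensor with more than one element, so test against None.
        # slice(None) selects every env, so dones=None resets all of them.
        dones_envs_id = torch.where(dones)[0] if dones is not None else slice(None)
        for hidden_state in self.hidden_states:
            hidden_state[..., dones_envs_id, :] = 0.0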

I don't know what the intended behavior is when dones is None, so by default the memory of all environments is set to 0.
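
For reference, the reset logic can be sketched as a standalone function that is easy to test in isolation. reset_hidden_states is a hypothetical helper, not part of rsl_rl, and it assumes that passing no dones tensor should reset every environment, which may not match what the maintainers intend.

```python
import torch


def reset_hidden_states(hidden_states, dones=None):
    """Zero the hidden states of finished envs (hypothetical standalone helper).

    hidden_states: iterable of tensors shaped (num_layers, num_envs, hidden_size),
                   e.g. (h, c) for an LSTM.
    dones: per-env flags of shape (num_envs,), any dtype, or None to reset every env
           (the None behavior is an assumption).
    """
    for hidden_state in hidden_states:
        if dones is None:
            hidden_state.zero_()
        else:
            # Convert the flags to a boolean mask so values are never read as env ids.
            hidden_state[:, dones.bool(), :] = 0.0


h = torch.ones(2, 4, 8)
c = torch.ones(2, 4, 8)

# Only env 3 is done: its memory is cleared, the others are untouched.
reset_hidden_states((h, c), torch.tensor([0, 0, 0, 1]))
only_env3_reset = bool((h[:, 3] == 0).all()) and bool((h[:, :3] == 1).all())

# No dones given: everything is cleared.
reset_hidden_states((h, c))
all_reset = bool((h == 0).all()) and bool((c == 0).all())
```

The explicit None check avoids calling bool() on a multi-element tensor, which PyTorch rejects as ambiguous.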

dxyy1 commented Sep 27, 2024

I also ran into the same issue. I think we can simplify your fix further with:

    def reset(self, dones=None):
        for hidden_state in self.hidden_states:
            hidden_state[..., dones.bool(), :] = 0.0
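
A quick check that the mask-based one-liner resets the same entries as the explicit env-id lookup, and of how it behaves with the dones=None default. This is a sketch on made-up tensors with illustrative names, not a test from the repository.

```python
import torch

dones = torch.tensor([0, 1, 0, 1])
h_mask = torch.ones(2, 4, 8)
h_ids = torch.ones(2, 4, 8)

# Variant 1: boolean mask, as in the simplified reset.
h_mask[..., dones.bool(), :] = 0.0

# Variant 2: explicit ids of the finished envs.
env_ids = torch.nonzero(dones, as_tuple=True)[0]
h_ids[..., env_ids, :] = 0.0

same_result = torch.equal(h_mask, h_ids)

# The simplified version has no special case for dones=None:
dones_default = None
try:
    dones_default.bool()
    none_supported = True
except AttributeError:
    none_supported = False
```

Both variants agree on the finished envs, so the choice is mostly about readability. If callers rely on reset() without arguments, the one-liner would still need a None check.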
