-
Notifications
You must be signed in to change notification settings - Fork 244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rgb observation to dagger #802
base: master
Are you sure you want to change the base?
Changes from 10 commits
5182ecf
61d816b
c3331f6
fc9838d
fb9498b
e54c36c
ee04383
6e2218a
68fe666
9ad2aaf
67341d5
c497b56
d3f79bf
2de9e49
c47cca6
276294b
071d2a7
a2ccd7e
93baa2d
54f33af
c711abf
58a0d70
0f080d4
c4d3e11
b93294a
3f17ff2
0212e0e
070ebf9
657e17e
1f8c12a
bd70ecd
bec464c
bef19e6
9aaf73f
5d6aa77
86fbcf1
8d1e0d6
96978d5
e95df9d
161ec95
90bdf57
6aa25ff
bf48c76
4ce1b57
de1b1c8
1a1a458
f7866f4
daa838d
803eab0
0ac6f54
be9798b
c7e6809
82fb558
03714cc
187e881
ae96521
535a986
de027c4
4caa151
20c6f56
aaf94da
ef53690
df35e3a
a69e052
c6ed675
d7d8db1
5dd1699
8481890
b0d8d4b
12b60b2
afbbe46
5036fcf
3f23de1
a44b193
1073967
ae17588
f140035
9a00e68
052cf00
c63761c
468f621
3c6def5
68d1ac2
125d19d
f8ebbc4
33162b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
import gymnasium as gym | ||
import numpy as np | ||
import numpy.typing as npt | ||
import torch as th | ||
from gymnasium.core import Env | ||
from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper | ||
|
||
|
@@ -213,7 +212,7 @@ def step(self, action): | |
return obs, rew, terminated, truncated, info | ||
|
||
|
||
class HumanReadableWrapper(gym.Wrapper): | ||
class HumanReadableWrapper(gym.ObservationWrapper): | ||
"""Adds human-readable observation to `obs` at every step.""" | ||
|
||
def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"): | ||
|
@@ -235,30 +234,8 @@ def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"): | |
) | ||
self._original_obs_key = original_obs_key | ||
super().__init__(env) | ||
self._update_obs_space() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we still need the observation space update. From the Gynmasium docs:
|
||
|
||
def _update_obs_space(self): | ||
# need to reset before render. | ||
self.env.reset() | ||
example_rgb_obs = self.env.render() | ||
new_rgb_space = gym.spaces.Box( | ||
low=0, | ||
high=255, | ||
shape=example_rgb_obs.shape, | ||
dtype=np.uint8, | ||
) | ||
curr_sapce = self.observation_space | ||
if isinstance(curr_sapce, gym.spaces.Dict): | ||
curr_sapce.spaces[HR_OBS_KEY] = new_rgb_space | ||
else: | ||
self.observation_space = gym.spaces.Dict( | ||
{ | ||
HR_OBS_KEY: new_rgb_space, | ||
self._original_obs_key: curr_sapce, | ||
}, | ||
) | ||
|
||
def _add_hr_obs( | ||
def observation( | ||
self, | ||
obs: Union[np.ndarray, Dict[str, np.ndarray]], | ||
) -> Dict[str, np.ndarray]: | ||
|
@@ -284,51 +261,3 @@ def _add_hr_obs( | |
raise KeyError(f"{HR_OBS_KEY!r} already exists in observation dict") | ||
obs[HR_OBS_KEY] = self.env.render() # type: ignore[assignment] | ||
return obs | ||
|
||
def reset(self, **kwargs): | ||
obs, info = super().reset(**kwargs) | ||
return self._add_hr_obs(obs), info | ||
|
||
def step(self, action): | ||
obs, rew, terminated, truncated, info = self.env.step(action) | ||
return self._add_hr_obs(obs), rew, terminated, truncated, info | ||
|
||
|
||
def remove_rgb_obs_space(obs_space: gym.Space) -> gym.Space: | ||
"""Removes rgb observation space from the observation space.""" | ||
if not isinstance(obs_space, gym.spaces.Dict): | ||
return obs_space | ||
if HR_OBS_KEY not in obs_space.spaces: | ||
return obs_space | ||
if len(obs_space.keys()) == 1: | ||
raise ValueError( | ||
"Only human readable observation space exists, can't remove it", | ||
) | ||
# keeps the original obs_space unchanged in case it is used elsewhere. | ||
new_obs_space = gym.spaces.Dict(obs_space.spaces.copy()) | ||
del new_obs_space.spaces[HR_OBS_KEY] | ||
if len(new_obs_space.spaces) == 1: | ||
# unwrap dictionary structure | ||
return next(iter(new_obs_space.values())) | ||
return new_obs_space | ||
|
||
|
||
def remove_rgb_obs( | ||
obs: Union[Dict[str, np.ndarray], Dict[str, th.Tensor], np.ndarray, th.Tensor], | ||
) -> Union[Dict[str, np.ndarray], Dict[str, th.Tensor], np.ndarray, th.Tensor]: | ||
"""Removes rgb observation from the observation.""" | ||
if not isinstance(obs, dict): | ||
return obs | ||
if HR_OBS_KEY not in obs: | ||
return obs | ||
if len(obs) == 1: | ||
raise ValueError( | ||
"Only human readable observation exists, can't remove it", | ||
) | ||
# keeps the original observation unchanged in case it is used elsewhere. | ||
new_obs = obs.copy() | ||
del new_obs[HR_OBS_KEY] | ||
if len(new_obs) == 1: | ||
# unwrap dictionary structure | ||
return next(iter(new_obs.values())) # type: ignore[return-value] | ||
return new_obs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copy-and-pasted comment doesn't make sense out of context (what is
self
here?)