Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rgb observation to dagger #802

Open
wants to merge 86 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
5182ecf
first pass of dict obs functionality
NixGD Sep 13, 2023
61d816b
cleanup DictObs
NixGD Sep 13, 2023
c3331f6
add dict space to test_types.py, fix some problems
NixGD Sep 14, 2023
fc9838d
add dict-obs test for rollout
NixGD Sep 14, 2023
fb9498b
add bc.py test
NixGD Sep 14, 2023
e54c36c
cleanup
NixGD Sep 14, 2023
ee04383
small fixes
NixGD Sep 14, 2023
6e2218a
small fixes
NixGD Sep 14, 2023
68fe666
fix type error in interactive.py
NixGD Sep 14, 2023
9ad2aaf
fix introduced error in mce_irl.py
NixGD Sep 14, 2023
67341d5
fix minor ci complaint
NixGD Sep 14, 2023
c497b56
add basic dictobs tests
NixGD Sep 14, 2023
d3f79bf
change default bc policy for dict obs space
NixGD Sep 14, 2023
2de9e49
refine rollout.py typechecks, comments
NixGD Sep 14, 2023
c47cca6
check rollout produces dictobs of correct shape
NixGD Sep 14, 2023
276294b
cleanup types and dictobs helpers
NixGD Sep 14, 2023
071d2a7
clean useless lines
NixGD Sep 14, 2023
a2ccd7e
clean up print statements
NixGD Sep 14, 2023
93baa2d
fix typos
NixGD Sep 15, 2023
54f33af
assert matching keys in from_obs_list
NixGD Sep 15, 2023
c711abf
move maybe_wrap, clean rollout
NixGD Sep 15, 2023
58a0d70
change policy callable to take dict[str, np.ndarray] not dictobs
NixGD Sep 15, 2023
0f080d4
rollout info wrapper supports dictobs
NixGD Sep 15, 2023
c4d3e11
fix from_obs_list key consistency check
NixGD Sep 15, 2023
b93294a
xfail save/load tests with dictobs
NixGD Sep 15, 2023
3f17ff2
doc for dictobs wrapper
NixGD Sep 15, 2023
0212e0e
don't error on int observations
NixGD Sep 15, 2023
070ebf9
lint fixes
NixGD Sep 15, 2023
657e17e
cleanup bc test for dict obs
NixGD Sep 15, 2023
1f8c12a
cleanup bc.py unwrapping
NixGD Sep 15, 2023
bd70ecd
cleanup rollout.py
NixGD Sep 15, 2023
bec464c
cleanup dictobs interface
NixGD Sep 15, 2023
bef19e6
small cleanups
NixGD Sep 15, 2023
9aaf73f
coverage fixes, test fix
NixGD Sep 15, 2023
5d6aa77
adjust error types
NixGD Sep 15, 2023
86fbcf1
docstrings for type helpers
NixGD Sep 15, 2023
8d1e0d6
add dict obs space support for density
NixGD Sep 15, 2023
96978d5
fix typos
NixGD Sep 15, 2023
e95df9d
Adam suggestions from code review
NixGD Sep 16, 2023
161ec95
small changes for code review
NixGD Sep 16, 2023
90bdf57
fix docstring
NixGD Sep 16, 2023
6aa25ff
remove FloatReward
ZiyueWang25 Oct 2, 2023
bf48c76
Merge remote-tracking branch 'origin/master' into support-dict-obs-space
ZiyueWang25 Oct 2, 2023
4ce1b57
Fix test_bc
ZiyueWang25 Oct 2, 2023
de1b1c8
Turn off GPU finding to avoid using gpu device
ZiyueWang25 Oct 2, 2023
1a1a458
Check None to ensure __add__ can work
ZiyueWang25 Oct 2, 2023
f7866f4
fix docstring
ZiyueWang25 Oct 2, 2023
daa838d
bypass pytype and lint test
ZiyueWang25 Oct 2, 2023
803eab0
format with black
ZiyueWang25 Oct 2, 2023
0ac6f54
Test dict space in density algo
ZiyueWang25 Oct 2, 2023
be9798b
black format
ZiyueWang25 Oct 2, 2023
c7e6809
small fix
ZiyueWang25 Oct 2, 2023
82fb558
Add DictObs into test_wrappers
ZiyueWang25 Oct 3, 2023
03714cc
fix format
ZiyueWang25 Oct 3, 2023
187e881
minor fix
ZiyueWang25 Oct 3, 2023
ae96521
type and lint fix
ZiyueWang25 Oct 3, 2023
535a986
Add policy training test
ZiyueWang25 Oct 3, 2023
de027c4
suppress line too long lint check on a line
ZiyueWang25 Oct 3, 2023
4caa151
acts to obs for clarity
ZiyueWang25 Oct 3, 2023
20c6f56
Add HumanReadableWrapper
ZiyueWang25 Oct 3, 2023
aaf94da
adjust wrapper and not set render_mode inside
ZiyueWang25 Oct 3, 2023
ef53690
fix dict env observation space
ZiyueWang25 Oct 3, 2023
df35e3a
add RemoveHumanReadableWrapper and update ob space
ZiyueWang25 Oct 4, 2023
a69e052
Remove some unnecessary helper functions
ZiyueWang25 Oct 4, 2023
c6ed675
include rgb obs to dagger algo
ZiyueWang25 Oct 4, 2023
d7d8db1
add wrappers tests and fix linter and typing
ZiyueWang25 Oct 5, 2023
5dd1699
change ob to obs
ZiyueWang25 Oct 5, 2023
8481890
allow not only dict type obs in dagger
ZiyueWang25 Oct 5, 2023
b0d8d4b
fix lint and test
ZiyueWang25 Oct 5, 2023
12b60b2
fix type and test
ZiyueWang25 Oct 5, 2023
afbbe46
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
ZiyueWang25 Oct 5, 2023
5036fcf
fix type
ZiyueWang25 Oct 5, 2023
3f23de1
resolve typing issue
ZiyueWang25 Oct 5, 2023
a44b193
Remove wrong type annotation in test
ZiyueWang25 Oct 5, 2023
1073967
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
ZiyueWang25 Oct 5, 2023
ae17588
resolve conflict
ZiyueWang25 Oct 5, 2023
f140035
add policy wrapper
ZiyueWang25 Oct 6, 2023
9a00e68
small fix
ZiyueWang25 Oct 6, 2023
052cf00
fix the data and policy wrappers
ZiyueWang25 Oct 6, 2023
c63761c
Use ObservationWrapper
ZiyueWang25 Oct 6, 2023
468f621
update naming
ZiyueWang25 Oct 6, 2023
3c6def5
update tests
ZiyueWang25 Oct 6, 2023
68d1ac2
update tests
ZiyueWang25 Oct 6, 2023
125d19d
update demo
ZiyueWang25 Oct 6, 2023
f8ebbc4
rgb to hr
ZiyueWang25 Oct 6, 2023
33162b2
small fix
ZiyueWang25 Oct 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions examples/train_dagger_atari_interactive_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,42 @@

import gymnasium as gym
import numpy as np
from stable_baselines3.common import vec_env
import torch as th
from stable_baselines3.common import torch_layers, vec_env

from imitation.algorithms import bc, dagger
from imitation.data import wrappers
from imitation.policies import interactive
from imitation.data import wrappers as data_wrappers
from imitation.policies import base as policy_base
from imitation.policies import interactive, obs_update_wrapper


def lr_schedule(_: float):
# Set lr_schedule to max value to force error if policy.optimizer
# is used by mistake (should use self.optimizer instead).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy-and-pasted comment doesn't make sense out of context (what is self here?)

return th.finfo(th.float32).max


if __name__ == "__main__":
rng = np.random.default_rng(0)

env = gym.make("PongNoFrameskip-v4", render_mode="rgb_array")
env = wrappers.HumanReadableWrapper(env)
venv = vec_env.DummyVecEnv([lambda: env])
hr_env = data_wrappers.HumanReadableWrapper(env)
venv = vec_env.DummyVecEnv([lambda: hr_env])
venv.seed(0)

expert = interactive.AtariInteractivePolicy(venv)
policy = policy_base.FeedForward32Policy(
observation_space=env.observation_space,
action_space=env.action_space,
lr_schedule=lr_schedule,
features_extractor_class=torch_layers.FlattenExtractor,
)
wrapped_policy = obs_update_wrapper.RemoveHR(policy, lr_schedule=lr_schedule)

bc_trainer = bc.BC(
observation_space=venv.observation_space,
action_space=venv.action_space,
observation_space=env.observation_space,
action_space=env.action_space,
policy=wrapped_policy,
rng=rng,
)

Expand Down
20 changes: 7 additions & 13 deletions src/imitation/algorithms/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from stable_baselines3.common import policies, torch_layers, utils, vec_env

from imitation.algorithms import base as algo_base
from imitation.data import rollout, types, wrappers
from imitation.data import rollout, types
from imitation.policies import base as policy_base
from imitation.util import logger as imit_logger
from imitation.util import util
Expand Down Expand Up @@ -294,7 +294,7 @@ def __init__(
observation_space: the observation space of the environment.
action_space: the action space of the environment.
rng: the random state to use for the random number generator.
policy: a Stable Baselines3 policy for learning; if unspecified,
policy: a Stable Baselines3 policy; if unspecified,
defaults to `FeedForward32Policy`.
demonstrations: Demonstrations from an expert (optional). Transitions
expressed directly as a `types.TransitionsMinimal` object, a sequence
Expand Down Expand Up @@ -334,19 +334,18 @@ def __init__(
self._bc_logger = BCLogger(self.logger)

self.action_space = action_space
obs_space_without_rgb = wrappers.remove_rgb_obs_space(observation_space)
self.observation_space = obs_space_without_rgb
self.observation_space = observation_space

self.rng = rng

if policy is None:
extractor = (
torch_layers.CombinedExtractor
if isinstance(obs_space_without_rgb, gym.spaces.Dict)
if isinstance(observation_space, gym.spaces.Dict)
else torch_layers.FlattenExtractor
)
policy = policy_base.FeedForward32Policy(
observation_space=obs_space_without_rgb,
observation_space=observation_space,
action_space=action_space,
# Set lr_schedule to max value to force error if policy.optimizer
# is used by mistake (should use self.optimizer instead).
Expand All @@ -355,7 +354,7 @@ def __init__(
)
self._policy = policy.to(utils.get_device(device))
# TODO(adam): make policy mandatory and delete observation/action space params?
assert self.policy.observation_space == obs_space_without_rgb
assert self.policy.observation_space == self.observation_space
assert self.policy.action_space == self.action_space

if optimizer_kwargs:
Expand Down Expand Up @@ -492,13 +491,8 @@ def process_batch():
lambda x: util.safe_to_tensor(x, device=self.policy.device),
types.maybe_unwrap_dictobs(batch["obs"]),
)
obs_tensor_without_rgb = wrappers.remove_rgb_obs(obs_tensor)
acts = util.safe_to_tensor(batch["acts"], device=self.policy.device)
training_metrics = self.loss_calculator(
self.policy,
obs_tensor_without_rgb,
acts,
)
training_metrics = self.loss_calculator(self.policy, obs_tensor, acts)

# Renormalise the loss to be averaged over the whole
# batch size instead of the minibatch size.
Expand Down
52 changes: 3 additions & 49 deletions src/imitation/algorithms/dagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common import policies, utils, vec_env
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn
from torch.utils import data as th_data

from imitation.algorithms import base, bc
from imitation.data import rollout, serialize, types, wrappers
from imitation.data import rollout, serialize, types
from imitation.util import logger as imit_logger
from imitation.util import util

Expand Down Expand Up @@ -306,26 +304,6 @@ class NeedsDemosException(Exception):
"""Signals demos need to be collected for current round before continuing."""


def _check_for_correct_spaces_with_rgb_env(
env_might_with_rgb: GymEnv,
obs_space: spaces.Space,
action_space: spaces.Space,
) -> None:
"""Checks that whether an environment has the same spaces as provided ones."""
if isinstance(obs_space, spaces.Dict):
assert wrappers.HR_OBS_KEY not in obs_space.spaces
env_obs_space = wrappers.remove_rgb_obs_space(env_might_with_rgb.observation_space)
if obs_space != env_obs_space:
raise ValueError(
f"Observation spaces do not match: obs {obs_space} != env {env_obs_space}",
)
env_action_space = env_might_with_rgb.action_space
if action_space != env_action_space:
raise ValueError(
f"Action spaces do not match: obs {action_space} != env {env_action_space}",
)


class DAggerTrainer(base.BaseImitationAlgorithm):
"""DAgger training class with low-level API suitable for interactive human feedback.

Expand Down Expand Up @@ -396,7 +374,7 @@ def __init__(
self._all_demos = []
self.rng = rng

_check_for_correct_spaces_with_rgb_env(
utils.check_for_correct_spaces(
self.venv,
bc_trainer.observation_space,
bc_trainer.action_space,
Expand Down Expand Up @@ -531,30 +509,6 @@ def extend_and_update(
logging.info(f"New round number is {self.round_num}")
return self.round_num

def _get_trainable_predict_fn(
self,
) -> Callable[[Union[Dict[str, np.ndarray], np.ndarray]], np.ndarray]:
"""Returns a function that uses `bc_trainer.policy` to predict observations.

Since bc_trainer.policy doesn't accept RGB observations, this function removes
The RGB observation part, if any, before passing the observation to prediction.

Returns:
A function that accepts a dictionary observation and returns a numpy array
of actions.
"""

def remove_rgb_and_predict(
obs: Union[Dict[str, np.ndarray], np.ndarray],
) -> np.ndarray:
obs_without_rgb = wrappers.remove_rgb_obs(obs)
assert isinstance(obs_without_rgb, (np.ndarray, dict))
fn = self.bc_trainer.policy.predict
# the Dict[str, Tensor] type seems hard to exclude from type annotation.
return fn(obs_without_rgb)[0] # type: ignore[arg-type]

return remove_rgb_and_predict

def create_trajectory_collector(self) -> InteractiveTrajectoryCollector:
"""Create trajectory collector to extend current round's demonstration set.

Expand All @@ -567,7 +521,7 @@ def create_trajectory_collector(self) -> InteractiveTrajectoryCollector:
beta = self.beta_schedule(self.round_num)
collector = InteractiveTrajectoryCollector(
venv=self.venv,
get_robot_acts=self._get_trainable_predict_fn(),
get_robot_acts=lambda obs: self.bc_trainer.policy.predict(obs)[0],
beta=beta,
save_dir=save_dir,
rng=self.rng,
Expand Down
75 changes: 2 additions & 73 deletions src/imitation/data/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import gymnasium as gym
import numpy as np
import numpy.typing as npt
import torch as th
from gymnasium.core import Env
from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper

Expand Down Expand Up @@ -213,7 +212,7 @@ def step(self, action):
return obs, rew, terminated, truncated, info


class HumanReadableWrapper(gym.Wrapper):
class HumanReadableWrapper(gym.ObservationWrapper):
"""Adds human-readable observation to `obs` at every step."""

def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"):
Expand All @@ -235,30 +234,8 @@ def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"):
)
self._original_obs_key = original_obs_key
super().__init__(env)
self._update_obs_space()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we still need the observation space update. From the Gynmasium docs:

The transformation defined in that method must be reflected by the env observation space. Otherwise, you need to specify the new observation space of the wrapper by setting self.observation_space in the init() method of your wrapper.


def _update_obs_space(self):
# need to reset before render.
self.env.reset()
example_rgb_obs = self.env.render()
new_rgb_space = gym.spaces.Box(
low=0,
high=255,
shape=example_rgb_obs.shape,
dtype=np.uint8,
)
curr_sapce = self.observation_space
if isinstance(curr_sapce, gym.spaces.Dict):
curr_sapce.spaces[HR_OBS_KEY] = new_rgb_space
else:
self.observation_space = gym.spaces.Dict(
{
HR_OBS_KEY: new_rgb_space,
self._original_obs_key: curr_sapce,
},
)

def _add_hr_obs(
def observation(
self,
obs: Union[np.ndarray, Dict[str, np.ndarray]],
) -> Dict[str, np.ndarray]:
Expand All @@ -284,51 +261,3 @@ def _add_hr_obs(
raise KeyError(f"{HR_OBS_KEY!r} already exists in observation dict")
obs[HR_OBS_KEY] = self.env.render() # type: ignore[assignment]
return obs

def reset(self, **kwargs):
obs, info = super().reset(**kwargs)
return self._add_hr_obs(obs), info

def step(self, action):
obs, rew, terminated, truncated, info = self.env.step(action)
return self._add_hr_obs(obs), rew, terminated, truncated, info


def remove_rgb_obs_space(obs_space: gym.Space) -> gym.Space:
"""Removes rgb observation space from the observation space."""
if not isinstance(obs_space, gym.spaces.Dict):
return obs_space
if HR_OBS_KEY not in obs_space.spaces:
return obs_space
if len(obs_space.keys()) == 1:
raise ValueError(
"Only human readable observation space exists, can't remove it",
)
# keeps the original obs_space unchanged in case it is used elsewhere.
new_obs_space = gym.spaces.Dict(obs_space.spaces.copy())
del new_obs_space.spaces[HR_OBS_KEY]
if len(new_obs_space.spaces) == 1:
# unwrap dictionary structure
return next(iter(new_obs_space.values()))
return new_obs_space


def remove_rgb_obs(
obs: Union[Dict[str, np.ndarray], Dict[str, th.Tensor], np.ndarray, th.Tensor],
) -> Union[Dict[str, np.ndarray], Dict[str, th.Tensor], np.ndarray, th.Tensor]:
"""Removes rgb observation from the observation."""
if not isinstance(obs, dict):
return obs
if HR_OBS_KEY not in obs:
return obs
if len(obs) == 1:
raise ValueError(
"Only human readable observation exists, can't remove it",
)
# keeps the original observation unchanged in case it is used elsewhere.
new_obs = obs.copy()
del new_obs[HR_OBS_KEY]
if len(new_obs) == 1:
# unwrap dictionary structure
return next(iter(new_obs.values())) # type: ignore[return-value]
return new_obs
3 changes: 0 additions & 3 deletions src/imitation/policies/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ def _choose_action(
if self.clear_screen_on_query:
util.clear_screen()

if isinstance(obs, dict):
raise ValueError("Dictionary observations are not supported here")

context = self._render(obs)
key = self._get_input_key()
self._clean_up(context)
Expand Down
Loading
Loading