Change callback for AdversarialTrainer #626

Open. Wants to merge 2 commits into base: master.
18 changes: 11 additions & 7 deletions src/imitation/algorithms/adversarial/common.py
@@ -2,13 +2,14 @@
 import abc
 import dataclasses
 import logging
-from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload
+from typing import Iterable, Iterator, List, Mapping, Optional, Type, overload

 import numpy as np
 import torch as th
 import torch.utils.tensorboard as thboard
 import tqdm
 from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
+from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.sac import policies as sac_policies
 from torch.nn import functional as F

@@ -421,7 +422,7 @@ def train_gen(
     def train(

Contributor:
One more thing - if you change the arguments, an update of training_adversarial.py will also be needed.

         self,
         total_timesteps: int,
-        callback: Optional[Callable[[int], None]] = None,
+        callback: Optional[List[BaseCallback]] = None

Contributor:
Do we want to change the semantics of the argument here, or should we rather deprecate the feature (and introduce a different parameter for an additional gen_callback)?

I think the suggestion in the original issue was to add a new gen_callback argument. (Btw, stable-baselines supports both a CallbackList and a plain list of callbacks, if we wanted to be fancy.)
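
For context, a minimal sketch (not part of this PR) of the stable-baselines3 behavior mentioned above: learn() accepts a single BaseCallback, a plain list of callbacks, or an explicit CallbackList, and wraps a plain list in a CallbackList internally. PrintingCallback is an illustrative name.

    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import BaseCallback, CallbackList

    class PrintingCallback(BaseCallback):
        """Toy callback that does nothing except keep training going."""

        def _on_step(self) -> bool:
            # Returning False would abort training early.
            return True

    model = PPO("MlpPolicy", "CartPole-v1")
    # All three forms are accepted; a plain list is wrapped in a CallbackList.
    model.learn(500, callback=PrintingCallback())
    model.learn(500, callback=[PrintingCallback(), PrintingCallback()])
    model.learn(500, callback=CallbackList([PrintingCallback()]))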

     ) -> None:
         """Alternates between training the generator and discriminator.

Contributor:
The last part of the description, "and finally a call to callback(round)", is probably misleading now.

@@ -434,10 +435,15 @@ def train(
         Args:
             total_timesteps: An upper bound on the number of transitions to sample
                 from the environment during training.
-            callback: A function called at the end of every round which takes in a
-                single argument, the round number. Round numbers are in
-                `range(total_timesteps // self.gen_train_timesteps)`.
+            callback: List of stable_baselines3 callbacks to be passed to the policy
+                learning function.
         """
+        if callback is not None:
+            if self.gen_callback is None:
+                self.gen_callback = callback
+            else:
+                self.gen_callback = callback + [self.gen_callback]

Contributor:
Can someone abuse the API by calling train() multiple times? If so, the value of self.gen_callback would contain a nested list, which is not correct. Generally, the value of gen_callback is currently Optional[BaseCallback], and we shouldn't change its type to a list at runtime.

Perhaps it would be better to add an optional callback argument to train_gen(), merge the callbacks there, and avoid the stateful change here?
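
One possible shape for that suggestion, as a sketch only and not part of this PR (the callback parameter on train_gen() is hypothetical):

    from typing import List, Optional

    from stable_baselines3.common.callbacks import BaseCallback

    def train(
        self,
        total_timesteps: int,
        callback: Optional[List[BaseCallback]] = None,
    ) -> None:
        # Build the merged list locally on every call instead of mutating
        # self.gen_callback, so repeated train() calls can neither nest
        # lists nor change the attribute's type at runtime.
        callbacks: List[BaseCallback] = list(callback or [])
        if self.gen_callback is not None:
            callbacks.append(self.gen_callback)
        n_rounds = total_timesteps // self.gen_train_timesteps
        for _ in range(n_rounds):
            # Hypothetical kwarg: train_gen() would accept the merged list.
            self.train_gen(self.gen_train_timesteps, callback=callbacks)
            self.train_disc()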

Contributor:
Also, can the learn_kwargs argument of train_gen() be removed, as discussed in the original issue #607?


         n_rounds = total_timesteps // self.gen_train_timesteps
         assert n_rounds >= 1, (
             "No updates (need at least "
@@ -450,8 +456,6 @@
             with networks.training(self.reward_train):
                 # switch to training mode (affects dropout, normalization)
                 self.train_disc()
-            if callback:
-                callback(r)
             self.logger.dump(self._global_step)

@overload
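
This diff removes the per-round callback(r) hook, so a user who depended on it could approximate the old behavior with a stable-baselines3 callback. A hedged sketch follows; EveryRoundCallback is a made-up name, and it assumes (as in this diff) that train() invokes the generator's learn() exactly once per round:

    from typing import Callable

    from stable_baselines3.common.callbacks import BaseCallback

    class EveryRoundCallback(BaseCallback):
        """Calls fn(round_number) each time a generator training phase ends."""

        def __init__(self, fn: Callable[[int], None]):
            super().__init__()
            self.fn = fn
            self.round = 0

        def _on_step(self) -> bool:
            return True

        def _on_training_end(self) -> None:
            # Each round triggers one learn() call, so the end of training
            # here corresponds to the end of a round.
            self.fn(self.round)
            self.round += 1

    # Usage with the signature proposed in this PR:
    # trainer.train(total_timesteps, callback=[EveryRoundCallback(print)])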