
Change callback for AdversarialTrainer #626

Open · wants to merge 2 commits into base: master

Conversation

@gunnxx gunnxx commented Nov 15, 2022

Changing the callback mechanism of AdversarialTrainer so that we can insert an sb3.EvalCallback. See #607.

@@ -421,7 +422,7 @@ def train_gen(
     def train(
         self,
         total_timesteps: int,
-        callback: Optional[Callable[[int], None]] = None,
+        callback: Optional[List[BaseCallback]] = None,
     ) -> None:
         """Alternates between training the generator and discriminator.

Contributor

The last part of the docstring, "and finally a call to callback(round)", is probably misleading now.

if self.gen_callback is None:
self.gen_callback = callback
else:
self.gen_callback = callback + [self.gen_callback]
Contributor

Can someone abuse the API by calling train() multiple times? If so, self.gen_callback would end up containing a nested list, which is not correct. More generally, gen_callback is currently typed Optional[BaseCallback], and we shouldn't change its type to a list at runtime.

Perhaps it would be better to add an optional callback argument to train_gen(), merge the callbacks there, and avoid the stateful change here?
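A minimal, library-free sketch of the stateless-merge suggestion above. BaseCallback and Trainer here are stand-ins for illustration, not the real stable_baselines3 / imitation classes, and the method names mirror the PR only loosely:

```python
from typing import List, Optional


class BaseCallback:
    """Stand-in for stable_baselines3.common.callbacks.BaseCallback."""
    pass


class Trainer:
    def __init__(self, gen_callback: Optional[BaseCallback] = None):
        # gen_callback keeps its declared type, Optional[BaseCallback];
        # it is never rebound to a list at runtime.
        self.gen_callback = gen_callback

    def train_gen(
        self, callback: Optional[List[BaseCallback]] = None
    ) -> List[BaseCallback]:
        # Merge the per-call callbacks with the trainer's own callback
        # locally, without mutating self.gen_callback.
        merged = list(callback) if callback is not None else []
        if self.gen_callback is not None:
            merged.append(self.gen_callback)
        return merged

    def train(
        self, rounds: int, callback: Optional[List[BaseCallback]] = None
    ) -> None:
        # Safe to call repeatedly: no nested lists accumulate.
        for _ in range(rounds):
            self.train_gen(callback=callback)
```

Because the merge is a local operation, calling train() any number of times leaves self.gen_callback untouched.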

Contributor

Also, can the learn_kwargs argument be removed from train_gen(), as discussed in the original issue #607?

@@ -421,7 +422,7 @@ def train_gen(
     def train(
         self,
         total_timesteps: int,
-        callback: Optional[Callable[[int], None]] = None,
+        callback: Optional[List[BaseCallback]] = None,
Contributor

Do we want to change the semantics of the argument here, or should we rather deprecate it (and introduce a separate parameter for an additional gen_callback)?

I think the suggestion in the original issue was to add a new gen_callback argument. (By the way, stable-baselines supports both a CallbackList and a plain list of callbacks, if we wanted to be fancy.)
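A small sketch of what accepting either form could look like. BaseCallback and CallbackList are stand-ins that only mimic the shape of the stable-baselines classes, and normalize_callback is a hypothetical helper, not an existing API:

```python
from typing import List, Optional, Union


class BaseCallback:
    """Stand-in for sb3's BaseCallback."""
    pass


class CallbackList(BaseCallback):
    """Stand-in for sb3's CallbackList: a callback wrapping other callbacks."""
    def __init__(self, callbacks: List[BaseCallback]):
        self.callbacks = callbacks


MaybeCallback = Union[None, BaseCallback, List[BaseCallback]]


def normalize_callback(callback: MaybeCallback) -> Optional[CallbackList]:
    """Accept None, a single callback, or a list, and return one CallbackList."""
    if callback is None:
        return None
    if isinstance(callback, CallbackList):
        return callback  # already wrapped; pass through unchanged
    if isinstance(callback, list):
        return CallbackList(callback)
    return CallbackList([callback])  # single callback: wrap in a list
```

With a normalizer like this, train() could keep one internal code path regardless of how the caller spells the argument.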

@@ -421,7 +422,7 @@ def train_gen(
def train(
Contributor

One more thing: if you change the arguments, training_adversarial.py will also need updating.
