feature(whl): add AWR algorithm. #828

Open

kxzxvbk wants to merge 5 commits into base: main
Conversation

@kxzxvbk (Contributor) commented Sep 10, 2024

Description

Implement the AWR (Advantage-Weighted Regression) algorithm, with a language model as the policy.
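For reference, AWR updates the policy by supervised regression toward actions reweighted by their exponentiated advantage: it maximizes `E[ log π(a|s) · exp(A(s,a) / β) ]`, where `A(s,a)` is an advantage estimate and `β` is a temperature (exposed as the `beta` hyperparameter in this PR). Here the "actions" are candidate prompts scored by the language-model policy.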

Related Issue

TODO

Check List

  • merge the latest version of the source branch/repo and resolve all conflicts
  • pass the style check
  • pass all the tests

@PaParaZz1 added the algo (Add new algorithm or improve old one) label on Sep 11, 2024
@PaParaZz1 changed the title from "feature(whl): Add AWR algorithm." to "feature(whl): add AWR algorithm." on Sep 11, 2024
@PaParaZz1 (Member) commented:

Add the AWR algorithm to the table in the README.

ding/model/template/language_transformer.py

```python
@@ -18,13 +18,16 @@ class LanguageTransformer(nn.Module):
    Interfaces:
        ``__init__``, ``forward``
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            model_name: str = "bert-base-uncased",
```

Collaborator:

Add some comments about why we use "bert-base-uncased" as the default?
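One possible wording for such a comment (the explanation below is my suggestion, not taken from the PR):

```python
# Default to "bert-base-uncased": a small, widely available Hugging Face
# encoder checkpoint, which keeps the default configuration lightweight and
# reproducible; any compatible encoder can be substituted via `model_name`.
model_name: str = "bert-base-uncased",
```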


```python
# Prepare train_sample (the question to be answered) and candidate_samples (the prompts to be selected).
train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
for ii in range(len(cand_samples)):
```

Collaborator:

Change `ii` into a more descriptive name?
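For example (one possible rename; the names are suggestions, not from the PR):

```python
for cand_idx, cand_sample in enumerate(cand_samples):
```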

```python
adv = torch.clamp(
    return_ - batch['value'], min=self._cfg.learn.norm_range[0], max=self._cfg.learn.norm_range[1]
)
policy_loss = -(log_prob * torch.exp(adv / self._cfg.learn.beta)).mean()
```

Collaborator:

Add comments about the key part of advantage-weighted regression.

Collaborator:

Should the operation `torch.exp(adv / self._cfg.learn.beta)` stop the gradient flow?

Member:

This operation is computed on the batch data rather than on the network output, so there is no need to stop the gradient.
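For readers following the thread, here is a minimal, self-contained sketch of the advantage-weighted loss being discussed (the function name and defaults are illustrative, not the PR's actual API):

```python
import torch


def awr_policy_loss(
        log_prob: torch.Tensor,  # log pi(a|s) of the taken actions
        return_: torch.Tensor,  # returns computed from the collected batch
        value: torch.Tensor,  # critic baseline V(s), also from the batch
        beta: float = 1.0,  # temperature: smaller beta weights good actions more sharply
        norm_range=(-1.0, 1.0),  # clamp range that keeps exp() numerically stable
) -> torch.Tensor:
    # Advantage estimate; clamping bounds the exponential weight.
    adv = torch.clamp(return_ - value, min=norm_range[0], max=norm_range[1])
    # The key AWR step: regress the policy toward actions with weights exp(A / beta).
    # Since `return_` and `value` come from batch data rather than the current
    # forward pass, no explicit detach() is needed, as the reply above notes.
    weight = torch.exp(adv / beta)
    return -(log_prob * weight).mean()
```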

```python
if len(real_act.shape) == 1:
    real_act = real_act.unsqueeze(-1)
# Calculate loss.
total_policy_loss, total_entropy_loss, total_value_loss = 0, 0, 0
```

Collaborator:

Update the comments here.
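One possible expansion of the `# Calculate loss.` comment (wording is my suggestion):

```python
# Accumulate the three AWR loss terms over the batch: the policy loss
# (advantage-weighted log-likelihood), the entropy bonus, and the value regression loss.
total_policy_loss, total_entropy_loss, total_value_loss = 0, 0, 0
```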

```python
# (float) Coefficient that controls the exp scale in the AWR algorithm.
beta=1.0,
# (float) Weight of entropy regularization in the loss function.
entropy_weight=0.01,
```

Collaborator:

Should we change it to a more generally applicable constant, such as 0.001?
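For context, such coefficients usually enter the total loss in the standard actor-critic combination, roughly as below (a sketch; `value_weight` is a placeholder name, not from this PR):

```python
total_loss = total_policy_loss + value_weight * total_value_loss - entropy_weight * total_entropy_loss
```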


```python
from easydict import EasyDict

tabmwp_prompt_pg_config = dict(
    exp_name='tabmwp_prompt_pg_seed0',
```

Member:

Polish the name; it is not PG in this file.
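For instance (a possible rename to match the algorithm; the exact names are my suggestion):

```python
tabmwp_prompt_awr_config = dict(
    exp_name='tabmwp_prompt_awr_seed0',
)
```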

ding/model/template/language_transformer.py

```python
@@ -36,10 +39,16 @@ def __init__(
        - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
        - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
            defaults to be ``True``.
        - hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
            correspond to the model you use. For bert-base-uncased, this value is 768.
```

Member:

There should be an indent here.
Labels: algo (Add new algorithm or improve old one)
Projects: none yet
3 participants