Skip to content

Commit

Permalink
Update PPO sentiments example to Llama-2 (switch model to NousResearch/Llama-2-7b-hf, retune lr and init_kl_coef, bump transformers to 4.31.0)
Browse files Browse the repository at this point in the history
  • Loading branch information
Duy Phung committed Jul 25, 2023
1 parent 288d4cb commit ae8ad21
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions examples/ppo_sentiments_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ def llama_config():
trainer="AcceleratePPOTrainer",
save_best=False,
),
model=ModelConfig(model_path="decapoda-research/llama-7b-hf", num_layers_unfrozen=2),
tokenizer=TokenizerConfig(tokenizer_path="decapoda-research/llama-7b-hf", truncation_side="right"),
model=ModelConfig(model_path="NousResearch/Llama-2-7b-hf", num_layers_unfrozen=2),
tokenizer=TokenizerConfig(tokenizer_path="NousResearch/Llama-2-7b-hf", truncation_side="right"),
optimizer=OptimizerConfig(
name="adamw", kwargs=dict(lr=1.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
name="adamw", kwargs=dict(lr=0.00003, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
),
scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-5)),
method=PPOConfig(
name="PPOConfig",
num_rollouts=128,
chunk_size=128,
ppo_epochs=4,
init_kl_coef=0.05,
init_kl_coef=0.001,
target=6,
horizon=10000,
gamma=1,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ tokenizers==0.13.3
torch==2.0.0+cu117
torchtyping==0.1.4
tqdm==4.65.0
transformers==4.28.1
transformers==4.31.0
triton==2.0.0
tritonclient==2.33.0
typeguard==3.0.2
Expand Down

0 comments on commit ae8ad21

Please sign in to comment.