Skip to content

Commit

Permalink
Fix LLaMA example (LLaMA 2) (#539)
Browse files (browse the repository at this point in the history)
* llama2 sentiments

* upgrade accelerate

* hyper params change

* fix(ppo_sentiments_llama): reduce `total_steps` for sample quality

---------

Co-authored-by: Duy Phung <[email protected]>
Co-authored-by: Duy Phung <[email protected]>
Co-authored-by: maxreciprocate <[email protected]>
  • Loading branch information
4 people committed Jul 31, 2023
1 parent 6f7f59d commit acd0a41
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions examples/ppo_sentiments_llama.py
Original file line number Diff line number Diff line change
@@ -31,26 +31,26 @@ def llama_config():
train=TrainConfig(
seq_length=1024,
epochs=100,
-total_steps=10000,
+total_steps=400,
batch_size=32,
checkpoint_interval=10000,
eval_interval=100,
pipeline="PromptPipeline",
trainer="AcceleratePPOTrainer",
save_best=False,
),
-model=ModelConfig(model_path="decapoda-research/llama-7b-hf", num_layers_unfrozen=2),
-tokenizer=TokenizerConfig(tokenizer_path="decapoda-research/llama-7b-hf", truncation_side="right"),
+model=ModelConfig(model_path="NousResearch/Llama-2-7b-hf", num_layers_unfrozen=2),
+tokenizer=TokenizerConfig(tokenizer_path="NousResearch/Llama-2-7b-hf", truncation_side="right"),
optimizer=OptimizerConfig(
-name="adamw", kwargs=dict(lr=1.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
+name="adamw", kwargs=dict(lr=1e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
),
scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-5)),
method=PPOConfig(
name="PPOConfig",
num_rollouts=128,
chunk_size=128,
ppo_epochs=4,
-init_kl_coef=0.05,
+init_kl_coef=0.001,
target=6,
horizon=10000,
gamma=1,
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu117
-accelerate==0.18.0
+accelerate==0.21.0
aiohttp==3.8.4
aiosignal==1.3.1
appdirs==1.4.4
@@ -71,7 +71,7 @@ tokenizers==0.13.3
torch==2.0.0+cu117
torchtyping==0.1.4
tqdm==4.65.0
-transformers==4.28.1
+transformers==4.31.0
triton==2.0.0
tritonclient==2.33.0
typeguard==3.0.2
Expand Down

0 comments on commit acd0a41

Please sign in to comment.