diff --git a/examples/ppo_sentiments_llama.py b/examples/ppo_sentiments_llama.py
index 5113f9071..50b50ce81 100644
--- a/examples/ppo_sentiments_llama.py
+++ b/examples/ppo_sentiments_llama.py
@@ -39,10 +39,10 @@ def llama_config():
             trainer="AcceleratePPOTrainer",
             save_best=False,
         ),
-        model=ModelConfig(model_path="decapoda-research/llama-7b-hf", num_layers_unfrozen=2),
-        tokenizer=TokenizerConfig(tokenizer_path="decapoda-research/llama-7b-hf", truncation_side="right"),
+        model=ModelConfig(model_path="NousResearch/Llama-2-7b-hf", num_layers_unfrozen=2),
+        tokenizer=TokenizerConfig(tokenizer_path="NousResearch/Llama-2-7b-hf", truncation_side="right"),
         optimizer=OptimizerConfig(
-            name="adamw", kwargs=dict(lr=1.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
+            name="adamw", kwargs=dict(lr=0.00003, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
         ),
         scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-5)),
         method=PPOConfig(
@@ -50,7 +50,7 @@ def llama_config():
             num_rollouts=128,
             chunk_size=128,
             ppo_epochs=4,
-            init_kl_coef=0.05,
+            init_kl_coef=0.001,
             target=6,
             horizon=10000,
             gamma=1,
diff --git a/requirements.txt b/requirements.txt
index 9770d2033..f64aa97f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -71,7 +71,7 @@ tokenizers==0.13.3
 torch==2.0.0+cu117
 torchtyping==0.1.4
 tqdm==4.65.0
-transformers==4.28.1
+transformers==4.31.0
 triton==2.0.0
 tritonclient==2.33.0
 typeguard==3.0.2
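
Not part of the patch: a minimal sanity-check sketch (assuming network access to the Hugging Face Hub and the packaging module being installed) for confirming that the installed transformers release supports the new Llama-2 checkpoint; Llama-2 support landed in transformers 4.31.0, which is why the requirements.txt pin is bumped alongside the model swap.

# sanity_check_llama2.py -- illustrative only, not part of the diff above
import transformers
from packaging import version
from transformers import AutoTokenizer

# Llama-2 checkpoints require transformers >= 4.31.0, matching the requirements.txt bump.
assert version.parse(transformers.__version__) >= version.parse("4.31.0")

# Load the tokenizer for the checkpoint the example now points at
# (downloads from the Hugging Face Hub on first use).
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf", truncation_side="right")
print(type(tokenizer).__name__, tokenizer.vocab_size)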