
OOM error with PEFT LoRA on Llama2-7B #601

Open
arpaiva opened this issue Sep 20, 2024 · 1 comment
Labels
bug Something isn't working

arpaiva commented Sep 20, 2024

🐛 Describe the bug

I'm trying to fine-tune Llama2-7B (to reproduce the experiments in a paper) using PEFT LoRA (0.124% trainable parameters). However, this results in an out-of-memory (OOM) error on a 32 GB V100 GPU. Using multiple GPUs, or setting the Accelerate DeepSpeed config to offload optimizer states and parameters to CPU, doesn't help and yields the same OOM error.
It seems that once the initial model fits on one GPU, everything else is assumed to fit as well. Or maybe everything is indeed supposed to fit and GPU memory is being wasted somehow. Any clues on how to fix this?
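To check whether memory is genuinely exhausted rather than fragmented, a small diagnostic can be added right before the failing forward pass; this is only a sketch using standard torch.cuda calls, not part of mmlu_sft.py:

import torch

def dump_gpu_memory(tag=""):
    # Print allocated vs. reserved memory for each visible GPU; purely diagnostic.
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 2**30
        reserved = torch.cuda.memory_reserved(i) / 2**30
        print(f"[{tag}] GPU {i}: {allocated:.2f} GiB allocated, {reserved:.2f} GiB reserved")

# e.g. call dump_gpu_memory("before loss") just before the trainer computes the loss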

Accelerate config:

compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
  train_batch_size: NUM_PROCS
  train_micro_batch_size_per_gpu: 1
downcast_bf16: no
dynamo_config: {}
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: fp16
num_machines: 1
num_processes: NUM_PROCS
rdzv_backend: static
same_network: true
use_cpu: false

(NUM_PROCS gets replaced with the number of GPUs)
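For context, a minimal sketch of the kind of PEFT LoRA setup implied by the 0.124% figure; the rank, alpha, dropout, and target modules below are assumptions rather than the exact values from the script:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# Assumed hyperparameters: r=16 on the attention q/v projections gives roughly 0.12%
# trainable parameters on Llama2-7B.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # reports the trainable-parameter fraction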

Error:

File "examples/mmlu_sft.py", line 335, in <module>
main(hparams)
File "examples/mmlu_sft.py", line 324, in main
trainer = trlx.train(
File "/scratch/repo/trlx/trlx.py", line 129, in train
trainer.learn()
File "/scratch/repo/trlx/trainer/accelerate_base_trainer.py", line 768, in learn
loss, stats = self.loss(microbatch)
File "/scratch/repo/trlx/trainer/accelerate_sft_trainer.py", line 70, in loss
loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1040, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0])
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 581, in forward
return model_forward(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 569, in call
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/usr/local/lib/python3.8/dist-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
return func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/peft/peft_model.py", line 918, in forward
return self.base_model(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/peft/tuners/tuners_utils.py", line 94, in forward
return self.model.forward(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 809, in forward
outputs = self.model(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 697, in forward
layer_outputs = decoder_layer(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 426, in forward
hidden_states = self.mlp(hidden_states)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 220, in forward
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 86.00 MiB (GPU 0; 31.74 GiB total capacity; 30.93 GiB already allocated; 59.31 MiB free; 31.02 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
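The allocator hint at the end of the message can be tried by setting PYTORCH_CUDA_ALLOC_CONF before the first CUDA allocation, although with 30.93 GiB already allocated this looks more like genuine over-allocation than fragmentation; the 128 MiB split size below is just an illustrative choice:

import os

# Must be set before torch allocates on the GPU; the split size is an illustrative assumption.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch  # imported after the allocator config so the setting takes effect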

Which trlX version are you using?

0.7.0

Additional system and package information

Python=3.8.10; transformers=4.32.1; Linux x64

arpaiva added the bug label on Sep 20, 2024
arpaiva commented Sep 20, 2024

By the way, I've tried many variations of the DeepSpeed settings (I believe I even tried running without Accelerate on a single GPU) and got OOM every time.
