
Crash when using save_state with deepspeed: model.state_dict functions incompatible with new deepspeed. #596

Open
JohannesAck opened this issue Jul 11, 2024 · 0 comments
Labels
bug Something isn't working


🐛 Describe the bug

I've recently been using the code provided in https://github.com/tlc4418/llm_optimization, which in turn uses trlx.

In doing so, I encountered a bug that makes trlX crash when saving a checkpoint. The crash comes from a recent change in DeepSpeed.

To reproduce, run this script (https://gist.github.com/JohannesAck/feb31ee5c491ca30771335296ec8b295) with `accelerate launch` and a config that enables DeepSpeed. It fails with:

```
Traceback (most recent call last):
  File "/workspaces/llm_optimization/crash_example.py", line 111, in <module>
    main(hparams)
  File "/workspaces/llm_optimization/crash_example.py", line 101, in main
    trlx.train(
  File "/usr/local/lib/python3.10/dist-packages/trlx/trlx.py", line 142, in train
    trainer.learn()
  File "/usr/local/lib/python3.10/dist-packages/trlx/trainer/accelerate_base_trainer.py", line 598, in learn
    self.save(directory)
  File "/usr/local/lib/python3.10/dist-packages/trlx/trainer/accelerate_base_trainer.py", line 312, in save
    self.accelerator.save_state(dst_dir, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2944, in save_state
    model.save_checkpoint(output_dir, ckpt_id, **save_model_func_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 3105, in save_checkpoint
    self._save_checkpoint(save_dir,
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 3299, in _save_checkpoint
    module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2540, in module_state_dict
    sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
  File "/usr/local/lib/python3.10/dist-packages/trlx/models/modeling_ppo.py", line 460, in state_dict
    state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
TypeError: dict() got multiple values for keyword argument 'prefix'
```

This is caused by microsoft/DeepSpeed#5408, which changed the call to `state_dict` to pass keyword arguments instead of positional ones:

```diff
- sd = self.module.state_dict(destination, prefix, keep_vars)
+ sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
```

trlX, however, assumes that these arguments are passed positionally:

```python
def state_dict(self, *args, heads_only=False, **kwargs):
    """
    Returns the state dictionary of the model. We add the state dictionary of the value head
    to the state dictionary of the wrapped model by prepending the key with `v_head.`.
    """
    state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
```

In L359 (the last line above), `dict(prefix="v_head.", **kwargs)` effectively becomes `dict(prefix="v_head.", prefix="")`, which has two values for `prefix` and therefore raises a `TypeError`.
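The clash can be reproduced without PyTorch or DeepSpeed. In this sketch, `module_state_dict` and `wrapper_state_dict` are hypothetical stand-ins: the first loosely mimics how recent PyTorch versions of `nn.Module.state_dict` accept deprecated positional arguments next to keyword-only `destination`, `prefix` and `keep_vars` (an assumption, reduced here to the `prefix` handling), and the second copies the `dict(prefix="v_head.", **kwargs)` pattern from trlX.

```python
def module_state_dict(*args, destination=None, prefix="", keep_vars=False):
    """Hypothetical, simplified stand-in for torch.nn.Module.state_dict.

    Positional arguments only fill in `prefix` when no keyword value was given.
    Returns the prefix that would be used for the keys.
    """
    if len(args) > 1 and prefix == "":
        prefix = args[1]
    return prefix


def wrapper_state_dict(*args, **kwargs):
    """Hypothetical stand-in for trlX's wrapper: forces prefix="v_head."."""
    return module_state_dict(*args, **dict(prefix="v_head.", **kwargs))


# Old DeepSpeed style: arguments are positional, so they land in *args and
# the forced prefix keyword does not clash with anything.
print(wrapper_state_dict(None, "", False))  # v_head.

# New DeepSpeed style: `prefix` arrives in **kwargs, so dict() sees it twice.
try:
    wrapper_state_dict(destination=None, prefix="", keep_vars=False)
except TypeError as err:
    print("TypeError:", err)
```

With the old positional call the forced `prefix="v_head."` keyword wins; with the new keyword call, building the `dict` fails before `state_dict` is ever reached.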

Workaround:

Downgrade DeepSpeed to a version < 0.14.1:

```
pip install 'deepspeed<0.14.1'
```

I'm not sure what the proper fix would be; simply ignoring the `prefix` argument doesn't sound great either. One option might be to ignore it if it's an empty string and raise an exception otherwise.
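A minimal sketch of that option, assuming a hypothetical helper `merge_prefix` (not part of trlX) and a simplified, hypothetical stand-in for `nn.Module.state_dict`: an empty incoming `prefix` is replaced with `"v_head."`, and any other value raises instead of being silently dropped.

```python
def merge_prefix(kwargs, prefix):
    """Hypothetical helper: return a copy of kwargs with `prefix` forced.

    An incoming prefix of "" (what DeepSpeed passes) is replaced; any other
    incoming prefix raises, since dropping it would produce misnamed keys.
    """
    merged = dict(kwargs)
    incoming = merged.pop("prefix", "")
    if incoming != "":
        raise ValueError(f"unsupported non-empty prefix {incoming!r}")
    merged["prefix"] = prefix
    return merged


def module_state_dict(*args, destination=None, prefix="", keep_vars=False):
    """Hypothetical stand-in for torch.nn.Module.state_dict; returns the prefix used."""
    if len(args) > 1 and prefix == "":
        prefix = args[1]
    return prefix


def wrapper_state_dict(*args, **kwargs):
    """Wrapper using the helper instead of dict(prefix=..., **kwargs)."""
    return module_state_dict(*args, **merge_prefix(kwargs, "v_head."))


# Both the old positional and the new keyword calling styles now work.
print(wrapper_state_dict(None, "", False))  # v_head.
print(wrapper_state_dict(destination=None, prefix="", keep_vars=False))  # v_head.

# A non-empty prefix is rejected rather than ignored.
try:
    wrapper_state_dict(prefix="model.")
except ValueError as err:
    print("ValueError:", err)
```

In trlX's `state_dict`, the call would then read `self.v_head.state_dict(*args, **merge_prefix(kwargs, "v_head."))`. An alternative would be to prepend the incoming prefix (`incoming + "v_head."`), which mirrors how PyTorch builds key names for nested submodules; which behavior is preferable depends on whether the trlX model is ever saved as a submodule of another module.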

Hope this helps somebody!

Which trlX version are you using?

trlx=0.7.0

Additional system and package information

deepspeed=0.14.4
